mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK_TILE] CK_TILE GEMM WMMA Support for GFX11/GFX12 (#2466)
* WMMA GEMM F16 Implementation
Signed-off-by: root <tianyuwu@amd.com>
* Self-review
Signed-off-by: root <tianyuwu@amd.com>
* ASIC check minor tweak
Signed-off-by: root <tianyuwu@amd.com>
* add missing include file
* Set GPU_TARGETS to gfx11/12 generic
Signed-off-by: root <tianyuwu@amd.com>
* INT8 GFX12
Signed-off-by: root <tianyuwu@amd.com>
* add int8x16 branch
* Fix CI script
Signed-off-by: root <tianyuwu@amd.com>
* Fix typo
Signed-off-by: root <tianyuwu@amd.com>
* Add CK_Tile WMMA example
Signed-off-by: Tianyuan Wu <tianyuwu@amd.com>
* Fix CI
Signed-off-by: Tianyuan Wu <tianyuwu@amd.com>
* fix clang format
* Set M/N_Warp Back to Constant
Signed-off-by: Tianyuan Wu <tianyuwu@amd.com>
* Use GemmConfigComputeV3 by default
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
* Enable CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT for gfx12
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
* Remove CK_Tile wmma gemm examples from the CI list
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
* Add atomic add fallback method for gfx11
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
* Fix typo
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
* Omit copyright year
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
* Support non-square cases
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
* Fix CI
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
* Add get_device_ip()
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
* Revert "Add atomic add fallback method for gfx11"
This reverts commit 4f664969c01b37976c8518c19833d9f1574cd746.
Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>
* Revert "Enable CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT for gfx12"
This reverts commit 949129a3858a825b2a2c4d3ec01663df18a165a5.
* Revise method name and typos
Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>
* clang-format
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
* Try fix CI
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
* Revert "Try fix CI"
This reverts commit 084c683227e64ab6a8137db00c8165fb05bdc902.
* clang-format
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
* Fix typo caused by merge
Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>
* Fix typo caused by merging
Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>
---------
Signed-off-by: root <tianyuwu@amd.com>
Signed-off-by: Tianyuan Wu <tianyuwu@amd.com>
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>
Co-authored-by: joye <joye@amd.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>
[ROCm/composable_kernel commit: 68134b60e4]
This commit is contained in:
@@ -327,6 +327,7 @@ endif()
|
||||
if(USE_OPT_GFX11)
|
||||
add_compile_options(-mcumode)
|
||||
add_compile_options(-mno-wavefrontsize64)
|
||||
add_compile_definitions(CK_TILE_WAVE32_ENABLED)
|
||||
message(STATUS "CK compiled with USE_OPT_GFX11 set to ${USE_OPT_GFX11}")
|
||||
endif()
|
||||
|
||||
@@ -336,6 +337,12 @@ if(ENABLE_ASM_DUMP)
|
||||
message("CK compiled with ENABLE_ASM_DUMP set to ${ENABLE_ASM_DUMP}")
|
||||
endif()
|
||||
|
||||
if(USE_OPT_GFX12 AND (SUPPORTED_GPU_TARGETS MATCHES "gfx12"))
|
||||
add_compile_options(-mno-wavefrontsize64)
|
||||
add_compile_definitions(CK_TILE_WAVE32_ENABLED)
|
||||
message(STATUS "CK compiled with USE_OPT_GFX12 set to ${USE_OPT_GFX12}")
|
||||
endif()
|
||||
|
||||
## Threads
|
||||
set(THREADS_PREFER_PTHREAD_FLAG ON)
|
||||
find_package(Threads REQUIRED)
|
||||
|
||||
4
Jenkinsfile
vendored
4
Jenkinsfile
vendored
@@ -1474,7 +1474,7 @@ pipeline {
|
||||
}
|
||||
agent{ label rocmnode("gfx1101") }
|
||||
environment{
|
||||
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx11-generic" -DCMAKE_CXX_FLAGS=" -O3 " """
|
||||
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx11-generic" -DUSE_OPT_GFX11=ON -DCMAKE_CXX_FLAGS=" -O3 " """
|
||||
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
|
||||
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
|
||||
-DGPU_TARGETS="gfx11-generic" \
|
||||
@@ -1495,7 +1495,7 @@ pipeline {
|
||||
}
|
||||
agent{ label rocmnode("gfx1201") }
|
||||
environment{
|
||||
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx12-generic" -DCMAKE_CXX_FLAGS=" -O3 " """
|
||||
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx12-generic" -DUSE_OPT_GFX12=ON -DCMAKE_CXX_FLAGS=" -O3 " """
|
||||
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
|
||||
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
|
||||
-DGPU_TARGETS="gfx12-generic" \
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
|
||||
@@ -172,6 +172,27 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3_WMMA : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 4;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV4 : public GemmConfigBase
|
||||
{
|
||||
|
||||
@@ -346,5 +346,4 @@ int main(int argc, char* argv[])
|
||||
// Return a non-zero code to indicate failure
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
return EXIT_SUCCESS;
|
||||
}
|
||||
|
||||
0
example/ck_tile/17_grouped_gemm/grouped_gemm.cpp
Executable file → Normal file
0
example/ck_tile/17_grouped_gemm/grouped_gemm.cpp
Executable file → Normal file
@@ -9,6 +9,7 @@
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/utility/ignore.hpp"
|
||||
|
||||
#define CK_TILE_S_CNT_MAX 0b1100'1111'0111'1111
|
||||
#define CK_TILE_VMCNT(cnt) \
|
||||
@@ -59,7 +60,7 @@ enum struct memory_operation_enum : std::uint16_t
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
|
||||
{
|
||||
#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
|
||||
#if defined(__GFX9__) || (!defined(__HIP_DEVICE_COMPILE__) && !defined(CK_TILE_WAVE32_ENABLED))
|
||||
return 64;
|
||||
#else
|
||||
return 32;
|
||||
@@ -230,4 +231,20 @@ CK_TILE_HOST_DEVICE constexpr const char* address_space_to_string(address_space_
|
||||
}
|
||||
}
|
||||
|
||||
// Architecture tags
|
||||
struct gfx11_t
|
||||
{
|
||||
};
|
||||
struct gfx12_t
|
||||
{
|
||||
};
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto get_device_arch()
|
||||
{
|
||||
#if defined(__gfx11__)
|
||||
return gfx11_t{};
|
||||
#else // if defined(__gfx12__)
|
||||
return gfx12_t{};
|
||||
#endif
|
||||
}
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,6 +6,10 @@
|
||||
#include "ck_tile/core/numeric/type_convert.hpp"
|
||||
#include "ck_tile/core/container/thread_buffer.hpp"
|
||||
|
||||
#define HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN \
|
||||
__has_builtin(__builtin_amdgcn_global_atomic_fadd_v2f16) && \
|
||||
__has_builtin(__builtin_amdgcn_global_atomic_fadd_v2bf16)
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename T, typename ComputeType>
|
||||
@@ -32,6 +36,14 @@ CK_TILE_HOST_DEVICE bf16x4_t add_bf16x4_t(const bf16x4_t& a, const bf16x4_t& b)
|
||||
return rtn;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE fp16x2_t add_f16x2_t(const fp16x2_t& a, const fp16x2_t& b)
|
||||
{
|
||||
fp16x2_t rtn;
|
||||
rtn[0] = add<fp16_t, float>(a[0], b[0]);
|
||||
rtn[1] = add<fp16_t, float>(a[1], b[1]);
|
||||
return rtn;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE fp8x4_t add_fp8x4_t(const fp8x4_t& a, const fp8x4_t& b)
|
||||
{
|
||||
fp8x4_t rtn;
|
||||
@@ -304,6 +316,44 @@ CK_TILE_DEVICE void atomic_add<bf8x8_t>(bf8x8_t* p_dst, bf8x8_t const& x)
|
||||
} while(cur_v.u64 != old_v);
|
||||
}
|
||||
|
||||
//
|
||||
// Atomic add for fp16x2_t
|
||||
//
|
||||
template <>
|
||||
CK_TILE_DEVICE void atomic_add<fp16x2_t>(fp16x2_t* p_dst, fp16x2_t const& x)
|
||||
{
|
||||
#if HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN
|
||||
__builtin_amdgcn_global_atomic_fadd_v2f16(c_style_pointer_cast<fp16x2_t*>(p_dst), x);
|
||||
#else
|
||||
union U32F162_ADDR
|
||||
{
|
||||
uint32_t* u32_a;
|
||||
fp16x2_t* f162_a;
|
||||
};
|
||||
|
||||
union U32F162
|
||||
{
|
||||
uint32_t u32;
|
||||
fp16x2_t f162;
|
||||
};
|
||||
|
||||
U32F162_ADDR dword_addr;
|
||||
U32F162 cur_v;
|
||||
U32F162 new_;
|
||||
uint32_t old_v, new_v;
|
||||
dword_addr.f162_a = p_dst;
|
||||
cur_v.u32 = *dword_addr.u32_a;
|
||||
|
||||
do
|
||||
{
|
||||
old_v = cur_v.u32;
|
||||
new_.f162 = add_f16x2_t(cur_v.f162, x);
|
||||
new_v = new_.u32;
|
||||
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
|
||||
} while(cur_v.u32 != old_v);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, index_t N>
|
||||
CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
|
||||
{
|
||||
@@ -311,6 +361,7 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
|
||||
(std::is_same<T, uint32_t>::value && (N == 1)) ||
|
||||
(std::is_same<T, float>::value && (N == 1 || N == 2)) ||
|
||||
(std::is_same<T, double>::value && (N == 1 || N == 2)) ||
|
||||
(std::is_same<T, fp16_t>::value && (N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, fp8_t>::value && (N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, bf8_t>::value && (N == 4 || N == 8 || N == 16)),
|
||||
@@ -406,6 +457,13 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
|
||||
atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst) + 1, x.template get_as<bf8x8_t>()[I1]);
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, fp16_t>::value)
|
||||
{
|
||||
static_for<0, N / 2, 1>{}([&](auto i) {
|
||||
atomic_add(c_style_pointer_cast<fp16x2_t*>(p_dst) + i,
|
||||
x.template get_as<fp16x2_t>()[i]);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, index_t N>
|
||||
|
||||
@@ -152,7 +152,7 @@
|
||||
// buffer atomic add: floating point
|
||||
#ifndef __HIP_DEVICE_COMPILE__ // for host code
|
||||
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
|
||||
#elif defined(__gfx9__) // for GPU code
|
||||
#elif defined(__gfx9__) || defined(__gfx12__) // for GPU code
|
||||
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
|
||||
#else // for GPU code
|
||||
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0
|
||||
@@ -274,6 +274,12 @@
|
||||
#define CK_TILE_WA_ISSUE_2028 0
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_WAVE32_ENABLED
|
||||
#if defined(__gfx11__) || defined(__gfx12__)
|
||||
#define CK_TILE_WAVE32_ENABLED
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// Y pointed to R, we don't see a valuable use case.
|
||||
// Will enforce encoding to check Y not pointed to R if set to zero
|
||||
#ifndef CK_TILE_ENC_SUPPORT_Y_TO_R
|
||||
|
||||
@@ -52,6 +52,19 @@ inline std::string get_device_name()
|
||||
}
|
||||
}
|
||||
|
||||
inline bool is_gfx11_supported()
|
||||
{
|
||||
return get_device_name() == "gfx1100" || get_device_name() == "gfx1101" ||
|
||||
get_device_name() == "gfx1102" || get_device_name() == "gfx1103" ||
|
||||
get_device_name() == "gfx1150" || get_device_name() == "gfx1151" ||
|
||||
get_device_name() == "gfx1152";
|
||||
}
|
||||
|
||||
inline bool is_gfx12_supported()
|
||||
{
|
||||
return get_device_name() == "gfx1200" || get_device_name() == "gfx1201";
|
||||
}
|
||||
|
||||
inline bool is_load_tr_supported()
|
||||
{
|
||||
// Check if load transpose is supported.
|
||||
|
||||
@@ -203,13 +203,13 @@ struct CShuffleEpilogue
|
||||
static constexpr index_t MPerIterationShuffle = std::get<0>(MNPerIterationShuffle);
|
||||
static constexpr index_t NPerIterationShuffle = std::get<1>(MNPerIterationShuffle);
|
||||
|
||||
using WG = WarpGemmMfmaDispatcher<ATypeToUse,
|
||||
BTypeToUse,
|
||||
AccDataType,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
KPerXdl,
|
||||
isCTransposed>;
|
||||
using WG = WarpGemmDispatcher<ATypeToUse,
|
||||
BTypeToUse,
|
||||
AccDataType,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
KPerXdl,
|
||||
isCTransposed>;
|
||||
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
|
||||
@@ -130,13 +130,13 @@ struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
|
||||
static constexpr index_t kKPerXdl = Problem::kKPerXdl;
|
||||
static constexpr index_t isCTransposed = Problem::isCTransposed;
|
||||
|
||||
using WG = WarpGemmMfmaDispatcher<ADataType,
|
||||
BTypeToUse,
|
||||
AccDataType,
|
||||
kMPerXdl,
|
||||
kNPerXdl,
|
||||
kKPerXdl,
|
||||
isCTransposed>;
|
||||
using WG = WarpGemmDispatcher<ADataType,
|
||||
BTypeToUse,
|
||||
AccDataType,
|
||||
kMPerXdl,
|
||||
kNPerXdl,
|
||||
kKPerXdl,
|
||||
isCTransposed>;
|
||||
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
|
||||
|
||||
@@ -430,13 +430,13 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
// using AccDataType = float;
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
|
||||
using BlockFlatmmPolicy = BlockFlatmmASmemBSmemCRegV1CustomPolicy<
|
||||
typename Problem::ADataType,
|
||||
|
||||
@@ -43,7 +43,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
|
||||
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<
|
||||
using WarpGemm = WarpGemmDispatcher<
|
||||
typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::AccDataType,
|
||||
@@ -78,18 +78,18 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
|
||||
|
||||
using WarpGemm =
|
||||
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
|
||||
typename Problem::OGradDataType,
|
||||
typename Problem::AccDataType,
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
|
||||
true,
|
||||
false, // SwizzleAccess
|
||||
false, // UseStructuredSparsity
|
||||
(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32)
|
||||
? WGAttrNumAccessEnum ::Double
|
||||
: WGAttrNumAccessEnum ::Single>;
|
||||
WarpGemmDispatcher<typename Problem::GemmDataType,
|
||||
typename Problem::OGradDataType,
|
||||
typename Problem::AccDataType,
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
|
||||
true,
|
||||
false, // SwizzleAccess
|
||||
false, // UseStructuredSparsity
|
||||
(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32)
|
||||
? WGAttrNumAccessEnum ::Double
|
||||
: WGAttrNumAccessEnum ::Single>;
|
||||
|
||||
using BlockGemmPolicy =
|
||||
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::GemmDataType,
|
||||
@@ -115,7 +115,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
typename Problem::BlockFmhaShape::Gemm2BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm2WarpTile>>;
|
||||
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<
|
||||
using WarpGemm = WarpGemmDispatcher<
|
||||
typename Problem::OGradDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::AccDataType,
|
||||
@@ -150,18 +150,18 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
typename Problem::BlockFmhaShape::Gemm3WarpTile>>;
|
||||
|
||||
using WarpGemm =
|
||||
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
|
||||
typename Problem::QDataType,
|
||||
typename Problem::AccDataType,
|
||||
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<0>{}),
|
||||
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<2>{}),
|
||||
true,
|
||||
false, // SwizzleAccess
|
||||
false, // UseStructuredSparsity
|
||||
(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32)
|
||||
? WGAttrNumAccessEnum ::Double
|
||||
: WGAttrNumAccessEnum ::Single>;
|
||||
WarpGemmDispatcher<typename Problem::GemmDataType,
|
||||
typename Problem::QDataType,
|
||||
typename Problem::AccDataType,
|
||||
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<0>{}),
|
||||
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<2>{}),
|
||||
true,
|
||||
false, // SwizzleAccess
|
||||
false, // UseStructuredSparsity
|
||||
(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32)
|
||||
? WGAttrNumAccessEnum ::Double
|
||||
: WGAttrNumAccessEnum ::Single>;
|
||||
|
||||
using BlockGemmPolicy =
|
||||
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::GemmDataType,
|
||||
@@ -187,14 +187,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
typename Problem::BlockFmhaShape::Gemm4BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm4WarpTile>>;
|
||||
|
||||
using WarpGemm =
|
||||
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::AccDataType,
|
||||
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<0>{}),
|
||||
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<2>{}),
|
||||
false>;
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::GemmDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::AccDataType,
|
||||
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<0>{}),
|
||||
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<2>{}),
|
||||
false>;
|
||||
|
||||
using BlockGemmPolicy =
|
||||
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::GemmDataType,
|
||||
|
||||
@@ -25,7 +25,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
|
||||
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
|
||||
|
||||
constexpr auto SwizzleA = false;
|
||||
using WarpGemm = WarpGemmMfmaDispatcher< //
|
||||
using WarpGemm = WarpGemmDispatcher< //
|
||||
typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::AccDataType,
|
||||
@@ -66,7 +66,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
|
||||
typename Problem::BlockFmhaShape::Gemm2WarpTile>>;
|
||||
|
||||
constexpr auto SwizzleA = false;
|
||||
using WarpGemm = WarpGemmMfmaDispatcher< //
|
||||
using WarpGemm = WarpGemmDispatcher< //
|
||||
typename Problem::OGradDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::AccDataType,
|
||||
@@ -106,7 +106,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
|
||||
typename BlockFmhaShape::Gemm4BlockWarps,
|
||||
typename BlockFmhaShape::Gemm4WarpTile>>;
|
||||
|
||||
using WarpGemm = WarpGemmMfmaDispatcher< //
|
||||
using WarpGemm = WarpGemmDispatcher< //
|
||||
typename Problem::GemmDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::AccDataType,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -512,14 +512,13 @@ struct BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
|
||||
|
||||
using WarpGemm =
|
||||
WarpGemmMfmaDispatcher<typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::SaccDataType,
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
|
||||
true>;
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::SaccDataType,
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
|
||||
true>;
|
||||
|
||||
using BlockGemmPolicy =
|
||||
BlockGemmARegBRegCRegV2CustomPolicy<typename Problem::QDataType,
|
||||
@@ -546,22 +545,22 @@ struct BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy
|
||||
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
|
||||
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<
|
||||
typename Problem::PDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::OaccDataType,
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
((Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) == 16 &&
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32) ||
|
||||
(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) == 32 &&
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 16))
|
||||
? WGAttrNumAccessEnum::Double
|
||||
: WGAttrNumAccessEnum::Single>;
|
||||
using WarpGemm =
|
||||
WarpGemmDispatcher<typename Problem::PDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::OaccDataType,
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
((Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) == 16 &&
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32) ||
|
||||
(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) == 32 &&
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 16))
|
||||
? WGAttrNumAccessEnum::Double
|
||||
: WGAttrNumAccessEnum::Single>;
|
||||
|
||||
using BlockGemmPolicy =
|
||||
BlockGemmARegBRegCRegV2CustomPolicy<typename Problem::PDataType,
|
||||
|
||||
@@ -956,20 +956,19 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
{
|
||||
return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<>{};
|
||||
// return
|
||||
// WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB<
|
||||
// WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB<
|
||||
// WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename
|
||||
// Problem::PDataType, typename Problem::VDataType>>>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return WarpGemmMfmaDispatcher<
|
||||
typename Problem::PDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::OaccDataType,
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
|
||||
true>{};
|
||||
return WarpGemmDispatcher<typename Problem::PDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::OaccDataType,
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
|
||||
true>{};
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
@@ -568,7 +568,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
|
||||
std::is_same_v<typename Problem::GDataType, ck_tile::bf16_t> &&
|
||||
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16)
|
||||
{
|
||||
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
||||
return WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<wg_ctrl>,
|
||||
2>>{};
|
||||
}
|
||||
@@ -576,7 +576,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
|
||||
std::is_same_v<typename Problem::GDataType, ck_tile::int8_t> &&
|
||||
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 32)
|
||||
{
|
||||
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
||||
return WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<wg_ctrl>,
|
||||
2>>{};
|
||||
}
|
||||
@@ -695,7 +695,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
|
||||
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
|
||||
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16)
|
||||
{
|
||||
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
||||
return WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<wg_ctrl>,
|
||||
2>>{};
|
||||
}
|
||||
@@ -703,7 +703,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
|
||||
std::is_same_v<typename Problem::DDataType, ck_tile::int8_t> &&
|
||||
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 32)
|
||||
{
|
||||
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
||||
return WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<wg_ctrl>,
|
||||
2>>{};
|
||||
}
|
||||
|
||||
@@ -58,9 +58,15 @@
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_16bit_traits.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_8bit_traits.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_base_traits.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -54,16 +54,16 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16<>{}, 2, 2);
|
||||
}
|
||||
#else
|
||||
using WG = WarpGemmMfmaDispatcher<ck_tile::half_t,
|
||||
ck_tile::half_t,
|
||||
float,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
wg_attr_num_access>;
|
||||
using WG = WarpGemmDispatcher<ck_tile::half_t,
|
||||
ck_tile::half_t,
|
||||
float,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
wg_attr_num_access>;
|
||||
return make_tuple(WG{}, 4, 1);
|
||||
#endif
|
||||
}
|
||||
@@ -71,16 +71,16 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
|
||||
std::is_same_v<typename Problem::BDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
using WG = WarpGemmMfmaDispatcher<ck_tile::bf16_t,
|
||||
ck_tile::bf16_t,
|
||||
float,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
wg_attr_num_access>;
|
||||
using WG = WarpGemmDispatcher<ck_tile::bf16_t,
|
||||
ck_tile::bf16_t,
|
||||
float,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
wg_attr_num_access>;
|
||||
return make_tuple(WG{}, 4, 1);
|
||||
}
|
||||
else
|
||||
|
||||
@@ -182,7 +182,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
|
||||
constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;
|
||||
|
||||
constexpr index_t WaveSize = 64;
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
|
||||
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
|
||||
|
||||
@@ -242,7 +242,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
|
||||
constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;
|
||||
|
||||
constexpr index_t WaveSize = 64;
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
|
||||
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
|
||||
|
||||
|
||||
@@ -182,7 +182,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(I1{});
|
||||
constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2{});
|
||||
|
||||
constexpr index_t WaveSize = 64;
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
|
||||
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
|
||||
|
||||
|
||||
@@ -32,16 +32,17 @@ struct GemmPipelineAgBgCrCompV4DefaultPolicy
|
||||
? WGAttrNumAccessEnum::Double
|
||||
: WGAttrNumAccessEnum::Single;
|
||||
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType, // AccDataType
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC,
|
||||
false,
|
||||
false,
|
||||
wg_attr_num_access>;
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType, // AccDataType
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC,
|
||||
false,
|
||||
false,
|
||||
wg_attr_num_access>;
|
||||
|
||||
using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
|
||||
@@ -21,15 +21,16 @@ struct GemmPipelineAgBgCrCompV5DefaultPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
|
||||
{
|
||||
// using AccDataType = float;
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType, // AccDataType
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType, // AccDataType
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
|
||||
using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
|
||||
@@ -390,16 +390,17 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
|
||||
{
|
||||
using AccDataType = float;
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
AccDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
using AccDataType = float;
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
AccDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
|
||||
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
|
||||
@@ -635,16 +635,17 @@ struct UniversalGemmPipelineAgBgCrPolicy
|
||||
: vector_size * 4 == thread_elements ? WGAttrNumAccessEnum::Quad
|
||||
: WGAttrNumAccessEnum::Invalid;
|
||||
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC,
|
||||
false,
|
||||
Problem::UseStructuredSparsity,
|
||||
wg_attr_num_access>;
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC,
|
||||
false,
|
||||
Problem::UseStructuredSparsity,
|
||||
wg_attr_num_access>;
|
||||
|
||||
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
|
||||
@@ -280,13 +280,13 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
|
||||
{
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
|
||||
using BlockWeightPreshufflePolicy =
|
||||
BlockWeightPreshuffleASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
|
||||
|
||||
@@ -15,19 +15,19 @@ namespace ck_tile {
|
||||
// fp16
|
||||
|
||||
using WarpGemmMfmaF16F16F32M32N32K8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfmaF16F16F32M16N16K16 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
#if defined(__gfx950__)
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M32N32K16<WGAttrCtlEnum::Default_>,
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplF16F16F32M32N32K16<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
#else
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
|
||||
using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
|
||||
2,
|
||||
AttrNumAccess>>;
|
||||
@@ -36,42 +36,42 @@ using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterate
|
||||
#if defined(__gfx950__)
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M16N16K32<WGAttrCtlEnum::Default_>,
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplF16F16F32M16N16K32<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
#else
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
|
||||
using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>,
|
||||
2,
|
||||
AttrNumAccess>>;
|
||||
#endif
|
||||
|
||||
using WarpGemmMfmaF16F16F32M32N32K8SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
|
||||
using WarpGemmMfmaF16F16F32M32N32K8SwizzleA = WarpGemmImpl<WarpGemmAttributeMfmaIterateK_SwizzleA<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
|
||||
1>>;
|
||||
|
||||
using WarpGemmMfmaF16F16F32M32N32K16SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
|
||||
using WarpGemmMfmaF16F16F32M32N32K16SwizzleA = WarpGemmImpl<WarpGemmAttributeMfmaIterateK_SwizzleA<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
|
||||
using WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
#if defined(__gfx950__)
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M32N32K16<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
#else
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
|
||||
2,
|
||||
AttrNumAccess>>;
|
||||
@@ -80,13 +80,13 @@ using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution =
|
||||
#if defined(__gfx950__)
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M16N16K32<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
#else
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>,
|
||||
2,
|
||||
AttrNumAccess>>;
|
||||
@@ -94,36 +94,36 @@ using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution =
|
||||
|
||||
#if defined(__gfx950__)
|
||||
using WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB<
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M16N16K32<WGAttrCtlEnum::Default_>,
|
||||
1>>;
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB<
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32<WGAttrCtlEnum::Default_>,
|
||||
1>>;
|
||||
#endif
|
||||
|
||||
using WarpGemmMfmaF16F16F32M32N32K8SwizzleBTransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB<
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
#if defined(__gfx950__)
|
||||
using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB<
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M32N32K16<WGAttrCtlEnum::Default_>>>;
|
||||
#else
|
||||
using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
#endif
|
||||
|
||||
using WarpGemmMfmaF16F16F32M4N64K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
|
||||
using WarpGemmMfmaF16F16F32M4N64K16 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M4N64K4<WGAttrCtlEnum::Default_>,
|
||||
4>>;
|
||||
|
||||
using WarpGemmMfmaF16F16F32M64N4K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
|
||||
using WarpGemmMfmaF16F16F32M64N4K16 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M64N4K4<WGAttrCtlEnum::Default_>,
|
||||
4>>;
|
||||
|
||||
@@ -136,19 +136,19 @@ using WarpGemmSmfmacF16F16F32M16N16K32 = WarpGemmSmfmacImpl<WarpGemmAttributeSmf
|
||||
|
||||
// bf16
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K16 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
#if defined(__gfx950__)
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16<WGAttrCtlEnum::Default_>,
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
#else
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
|
||||
2,
|
||||
AttrNumAccess>>;
|
||||
@@ -157,43 +157,43 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl<WarpGemmAtrributeMfmaItera
|
||||
#if defined(__gfx950__)
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32<WGAttrCtlEnum::Default_>,
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
#else
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>,
|
||||
2,
|
||||
AttrNumAccess>>;
|
||||
#endif
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl<WarpGemmAttributeMfmaIterateK_SwizzleA<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
|
||||
1>>;
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaIterateK_SwizzleA<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
#if defined(__gfx950__)
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
#else
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
|
||||
2,
|
||||
AttrNumAccess>>;
|
||||
@@ -202,153 +202,153 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution =
|
||||
#if defined(__gfx950__)
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
#else
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>,
|
||||
2,
|
||||
AttrNumAccess>>;
|
||||
#endif
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleBTransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB<
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
#if defined(__gfx950__)
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB<
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16<WGAttrCtlEnum::Default_>>>;
|
||||
#else
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
#endif
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M4N64K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
|
||||
using WarpGemmMfmaBf16Bf16F32M4N64K16 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4<WGAttrCtlEnum::Default_>,
|
||||
4>>;
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M64N4K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
|
||||
using WarpGemmMfmaBf16Bf16F32M64N4K16 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4<WGAttrCtlEnum::Default_>,
|
||||
4>>;
|
||||
|
||||
// fp8
|
||||
|
||||
using WarpGemmMfma_f32_32x32x16_fp8_fp8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>>>;
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_f32_32x32x16_fp8_bf8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>>>;
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_f32_32x32x16_bf8_fp8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8<WGAttrCtlEnum::Default_>>>;
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_f32_32x32x16_bf8_bf8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>>>;
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_f32_32x32x32_fp8_fp8 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
|
||||
using WarpGemmMfma_f32_32x32x32_fp8_fp8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
|
||||
using WarpGemmMfma_f32_32x32x32_bf8_bf8 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
|
||||
using WarpGemmMfma_f32_32x32x32_bf8_bf8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
|
||||
using WarpGemmMfma_f32_16x16x32_fp8_fp8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8<WGAttrCtlEnum::Default_>>>;
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_f32_16x16x32_bf8_bf8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>>>;
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_f32_16x16x64_fp8_fp8 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
|
||||
using WarpGemmMfma_f32_16x16x64_fp8_fp8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
|
||||
using WarpGemmMfma_f32_16x16x64_bf8_bf8 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
|
||||
using WarpGemmMfma_f32_16x16x64_bf8_bf8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_16x16x128_fp8_fp8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_fp8_fp8<WGAttrCtlEnum::Default_>,
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_fp8_fp8<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_16x16x128_fp8_bf8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_fp8_bf8<WGAttrCtlEnum::Default_>,
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_fp8_bf8<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_16x16x128_bf8_fp8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_fp8<WGAttrCtlEnum::Default_>,
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_fp8<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_16x16x128_bf8_bf8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_bf8<WGAttrCtlEnum::Default_>,
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_bf8<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_32x32x64_fp8_fp8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_fp8<WGAttrCtlEnum::Default_>,
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_fp8<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_32x32x64_fp8_bf8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_bf8<WGAttrCtlEnum::Default_>,
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_bf8<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_32x32x64_bf8_fp8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_fp8<WGAttrCtlEnum::Default_>,
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_fp8<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_32x32x64_bf8_bf8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_bf8<WGAttrCtlEnum::Default_>,
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_bf8<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
|
||||
using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
template <index_t swizzle_factor = 2>
|
||||
using WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, fp8_t, WGAttrCtlEnum::Default_>,
|
||||
2,
|
||||
swizzle_factor>>;
|
||||
|
||||
// int8
|
||||
using WarpGemmMfma_i32_32x32x16_i8_i8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<WGAttrCtlEnum::Default_>>>;
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_i32_32x32x16_i8_i8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_i32_16x16x32_i8_i8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_i32_16x16x32_i8<WGAttrCtlEnum::Default_>>>;
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_i32_16x16x32_i8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_i32_16x16x32_i8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -19,7 +19,7 @@ enum class WGAttrNumAccessEnum
|
||||
|
||||
template <typename WarpGemmAttributeMfmaImpl_,
|
||||
WGAttrNumAccessEnum AttrNumAccess_ = WGAttrNumAccessEnum::Single>
|
||||
struct WarpGemmAtrributeMfma
|
||||
struct WarpGemmAttributeMfma
|
||||
{
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
|
||||
static constexpr auto AttrNumAccess = AttrNumAccess_;
|
||||
@@ -103,7 +103,7 @@ struct WarpGemmAtrributeMfma
|
||||
template <typename WarpGemmAttributeMfmaImpl_,
|
||||
index_t kKIter,
|
||||
WGAttrNumAccessEnum AttrNumAccess_ = WGAttrNumAccessEnum::Single>
|
||||
struct WarpGemmAtrributeMfmaIterateK
|
||||
struct WarpGemmAttributeMfmaIterateK
|
||||
{
|
||||
static_assert(kKIter > 0, "wrong!");
|
||||
|
||||
@@ -367,7 +367,7 @@ struct WarpGemmAtrributeMfmaIterateK
|
||||
|
||||
template <typename WarpGemmAttributeMfmaImpl_,
|
||||
WGAttrNumAccessEnum AttrNumAccess_ = WGAttrNumAccessEnum::Single>
|
||||
struct WarpGemmAtrributeMfmaTransposedCDistribution
|
||||
struct WarpGemmAttributeMfmaTransposedCDistribution
|
||||
{
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
|
||||
static constexpr auto AttrNumAccess = AttrNumAccess_;
|
||||
@@ -450,7 +450,7 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution
|
||||
};
|
||||
|
||||
template <typename WarpGemmAttributeMfmaImpl_, index_t SFactor_ = 2>
|
||||
struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB
|
||||
struct WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB
|
||||
{
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
|
||||
|
||||
@@ -546,7 +546,7 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB
|
||||
template <typename WarpGemmAttributeMfmaImpl_,
|
||||
index_t kKIter,
|
||||
WGAttrNumAccessEnum AttrNumAccess_ = WGAttrNumAccessEnum::Single>
|
||||
struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
|
||||
struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution
|
||||
{
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
|
||||
static constexpr auto AttrNumAccess = AttrNumAccess_;
|
||||
@@ -574,13 +574,13 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding()
|
||||
{
|
||||
return WarpGemmAtrributeMfmaIterateK<Impl, kKIter, AttrNumAccess>::
|
||||
return WarpGemmAttributeMfmaIterateK<Impl, kKIter, AttrNumAccess>::
|
||||
get_bwarp_dstr_encoding();
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding()
|
||||
{
|
||||
return WarpGemmAtrributeMfmaIterateK<Impl, kKIter, AttrNumAccess>::
|
||||
return WarpGemmAttributeMfmaIterateK<Impl, kKIter, AttrNumAccess>::
|
||||
get_awarp_dstr_encoding();
|
||||
}
|
||||
|
||||
@@ -696,7 +696,7 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
|
||||
};
|
||||
|
||||
template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter, index_t SFactor_ = 2>
|
||||
struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
|
||||
struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB
|
||||
{
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
|
||||
|
||||
@@ -840,7 +840,7 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
|
||||
};
|
||||
|
||||
template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter, index_t SFactor_ = 2>
|
||||
struct WarpGemmAtrributeMfmaIterateK_SwizzleA
|
||||
struct WarpGemmAttributeMfmaIterateK_SwizzleA
|
||||
{
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
|
||||
|
||||
|
||||
147
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp
Normal file
147
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp
Normal file
@@ -0,0 +1,147 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// TODO: currently only support 16 bit input, which means only support tr16_b128; will use ADataType
|
||||
// to determine the layout in the future
|
||||
template <typename Impl>
|
||||
struct AWarpDstrEncodingTrait
|
||||
{
|
||||
using type = tile_distribution_encoding<
|
||||
sequence<Impl::kRepeat>,
|
||||
tuple<sequence<Impl::kAMLane>,
|
||||
sequence<Impl::kABK0PerLane, Impl::kABKLane, Impl::kABK1PerLane>>,
|
||||
tuple<typename Impl::kABPs2RHssMajor>,
|
||||
tuple<typename Impl::kABPs2RHssMinor>,
|
||||
typename Impl::kABYs2RHsMajor,
|
||||
typename Impl::kABYs2RHsMinor>;
|
||||
};
|
||||
|
||||
template <typename Impl>
|
||||
struct BWarpDstrEncodingTrait
|
||||
{
|
||||
using type = tile_distribution_encoding<
|
||||
sequence<Impl::kRepeat>,
|
||||
tuple<sequence<Impl::kBNLane>,
|
||||
sequence<Impl::kABK0PerLane, Impl::kABKLane, Impl::kABK1PerLane>>,
|
||||
tuple<typename Impl::kABPs2RHssMajor>,
|
||||
tuple<typename Impl::kABPs2RHssMinor>,
|
||||
typename Impl::kABYs2RHsMajor,
|
||||
typename Impl::kABYs2RHsMinor>;
|
||||
};
|
||||
|
||||
template <typename Impl>
|
||||
struct CWarpDstrEncodingTrait
|
||||
{
|
||||
using type = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>,
|
||||
sequence<Impl::kCNLane>>,
|
||||
tuple<typename Impl::kCPs2RHssMajor>,
|
||||
tuple<typename Impl::kCPs2RHssMinor>,
|
||||
typename Impl::kCYs2RHsMajor,
|
||||
typename Impl::kCYs2RHsMinor>;
|
||||
};
|
||||
|
||||
template <typename WarpGemmAttributeWmmaImpl_, bool kTransC = false>
|
||||
struct WarpGemmAttributeWmma
|
||||
{
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeWmmaImpl_>;
|
||||
|
||||
using ADataType = typename Impl::ADataType;
|
||||
using BDataType = typename Impl::BDataType;
|
||||
using CDataType = typename Impl::CDataType;
|
||||
|
||||
using AVecType = typename Impl::AVecType;
|
||||
using BVecType = typename Impl::BVecType;
|
||||
using CVecType = typename Impl::CVecType;
|
||||
|
||||
static constexpr index_t kM = Impl::kM;
|
||||
static constexpr index_t kN = Impl::kN;
|
||||
static constexpr index_t kK = Impl::kK;
|
||||
static constexpr index_t kKPerThread = Impl::kABK0PerLane * Impl::kABK1PerLane;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
|
||||
|
||||
// 16 bit input, kAMLane = 16, kABK0PerLane = 4, kABKLane = 2, kABK1PerLane = 2
|
||||
// 8 bit input, kAMLane = 16, kABK0PerLane = 2, kABKLane = 2, kABK1PerLane = 4
|
||||
using AWarpDstrEncoding = typename AWarpDstrEncodingTrait<Impl>::type;
|
||||
using BWarpDstrEncoding = typename BWarpDstrEncodingTrait<Impl>::type;
|
||||
|
||||
// kCM0PerLane = 4, kCMLane = 2, kCM1PerLane = 2, kCNLane = 16 for 16 bit input
|
||||
// kCM0PerLane = 2, kCMLane = 2, kCM1PerLane = 4, kCNLane = 16 for 8 bit input
|
||||
using CWarpDstrEncoding = typename CWarpDstrEncodingTrait<Impl>::type;
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
template <bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
if constexpr(kTransC)
|
||||
{
|
||||
Impl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Impl{}(c_vec, a_vec, b_vec, bool_constant<post_nop_>{});
|
||||
}
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
if constexpr(kTransC)
|
||||
{
|
||||
return Impl{}(b_vec, a_vec);
|
||||
}
|
||||
else
|
||||
{
|
||||
return Impl{}(a_vec, b_vec);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
index_t M_Warp_Tile,
|
||||
index_t N_Warp_Tile,
|
||||
index_t K_Warp_Tile>
|
||||
CK_TILE_HOST bool check_wmma_supported()
|
||||
{
|
||||
if(is_gfx12_supported())
|
||||
{
|
||||
return has_wmma_traits_v<gfx12_t,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile>;
|
||||
}
|
||||
else if(is_gfx11_supported())
|
||||
{
|
||||
return has_wmma_traits_v<gfx11_t,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile>;
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
132
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp
Normal file
132
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp
Normal file
@@ -0,0 +1,132 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/config.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Base traits for WMMA operations
|
||||
template <typename Arch,
|
||||
typename AType,
|
||||
typename BType,
|
||||
typename CType,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K>
|
||||
struct WmmaTraits;
|
||||
|
||||
// Generic WMMA implementation using traits
|
||||
template <typename Traits>
|
||||
struct WarpGemmAttributeWmmaImpl
|
||||
{
|
||||
using ADataType = typename Traits::ADataType;
|
||||
using BDataType = typename Traits::BDataType;
|
||||
using CDataType = typename Traits::CDataType;
|
||||
|
||||
using AVecType = typename Traits::AVecType;
|
||||
using BVecType = typename Traits::BVecType;
|
||||
using CVecType = typename Traits::CVecType;
|
||||
|
||||
// Forward all static constants and type aliases
|
||||
static constexpr index_t kM = Traits::kM;
|
||||
static constexpr index_t kN = Traits::kN;
|
||||
static constexpr index_t kK = Traits::kK;
|
||||
|
||||
static constexpr index_t kRepeat = Traits::kRepeat;
|
||||
static constexpr index_t kAMLane = Traits::kAMLane;
|
||||
static constexpr index_t kBNLane = Traits::kBNLane;
|
||||
static constexpr index_t kABK0PerLane = Traits::kABK0PerLane;
|
||||
static constexpr index_t kABKLane = Traits::kABKLane;
|
||||
static constexpr index_t kABK1PerLane = Traits::kABK1PerLane;
|
||||
|
||||
static constexpr index_t kCMLane = Traits::kCMLane;
|
||||
static constexpr index_t kCNLane = Traits::kCNLane;
|
||||
static constexpr index_t kCM0PerLane = Traits::kCM0PerLane;
|
||||
static constexpr index_t kCM1PerLane = Traits::kCM1PerLane;
|
||||
|
||||
using kABPs2RHssMajor = typename Traits::kABPs2RHssMajor;
|
||||
using kABPs2RHssMinor = typename Traits::kABPs2RHssMinor;
|
||||
using kABYs2RHsMajor = typename Traits::kABYs2RHsMajor;
|
||||
using kABYs2RHsMinor = typename Traits::kABYs2RHsMinor;
|
||||
|
||||
using kCPs2RHssMajor = typename Traits::kCPs2RHssMajor;
|
||||
using kCPs2RHssMinor = typename Traits::kCPs2RHssMinor;
|
||||
using kCYs2RHsMajor = typename Traits::kCYs2RHsMajor;
|
||||
using kCYs2RHsMinor = typename Traits::kCYs2RHsMinor;
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
template <bool clamp = false, bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
c_vec = Traits::template wmma_intrinsic<clamp>(a_vec, b_vec, c_vec);
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
template <bool clamp = false>
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
return bit_cast<CVecType>(
|
||||
Traits::template wmma_intrinsic<clamp>(a_vec, b_vec, CVecType{0.f}));
|
||||
}
|
||||
};
|
||||
|
||||
using DeviceIp = remove_cvref_t<decltype(ck_tile::get_device_arch())>;
|
||||
using WarpGemmAttributeWmmaImpl_f32_16x16x16_f16_f16 =
|
||||
WarpGemmAttributeWmmaImpl<WmmaTraits<DeviceIp, fp16_t, fp16_t, float, 16, 16, 16>>;
|
||||
|
||||
using WarpGemmAttributeWmmaImpl_f32_16x16x16_bf16_bf16 =
|
||||
WarpGemmAttributeWmmaImpl<WmmaTraits<DeviceIp, bf16_t, bf16_t, float, 16, 16, 16>>;
|
||||
|
||||
using WarpGemmAttributeWmmaImpl_i32_16x16x16_i8_i8 =
|
||||
WarpGemmAttributeWmmaImpl<WmmaTraits<DeviceIp, int8_t, int8_t, int32_t, 16, 16, 16>>;
|
||||
|
||||
using WarpGemmAttributeWmmaImpl_f32_16x16x16_f8_f8 =
|
||||
WarpGemmAttributeWmmaImpl<WmmaTraits<gfx12_t, fp8_t, fp8_t, float, 16, 16, 16>>;
|
||||
|
||||
using WarpGemmAttributeWmmaImpl_f32_16x16x16_bf8_bf8 =
|
||||
WarpGemmAttributeWmmaImpl<WmmaTraits<gfx12_t, bf8_t, bf8_t, float, 16, 16, 16>>;
|
||||
|
||||
using WarpGemmAttributeWmmaImpl_f32_16x16x16_f8_bf8 =
|
||||
WarpGemmAttributeWmmaImpl<WmmaTraits<gfx12_t, fp8_t, bf8_t, float, 16, 16, 16>>;
|
||||
|
||||
using WarpGemmAttributeWmmaImpl_f32_16x16x16_bf8_f8 =
|
||||
WarpGemmAttributeWmmaImpl<WmmaTraits<gfx12_t, bf8_t, fp8_t, float, 16, 16, 16>>;
|
||||
|
||||
template <typename Arch,
|
||||
typename AType,
|
||||
typename BType,
|
||||
typename CType,
|
||||
index_t warp_m,
|
||||
index_t warp_n,
|
||||
index_t warp_k>
|
||||
struct has_wmma_traits
|
||||
{
|
||||
template <typename T>
|
||||
static auto
|
||||
test(int) -> decltype(std::declval<
|
||||
typename WmmaTraits<T, AType, BType, CType, warp_m, warp_n, warp_k>::
|
||||
ADataType>(),
|
||||
std::true_type{});
|
||||
|
||||
template <typename>
|
||||
static std::false_type test(...);
|
||||
|
||||
static constexpr bool value = decltype(test<Arch>(0))::value;
|
||||
};
|
||||
|
||||
template <typename Arch,
|
||||
typename AType,
|
||||
typename BType,
|
||||
typename CType,
|
||||
index_t warp_m,
|
||||
index_t warp_n,
|
||||
index_t warp_k>
|
||||
constexpr bool has_wmma_traits_v =
|
||||
has_wmma_traits<Arch, AType, BType, CType, warp_m, warp_n, warp_k>::value;
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,87 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "warp_gemm_attribute_wmma_impl_base_traits.hpp"
|
||||
namespace ck_tile {
|
||||
// fp16 specialization - GFX11
|
||||
template <>
|
||||
struct WmmaTraits<gfx11_t, fp16_t, fp16_t, float, 16, 16, 16>
|
||||
: WmmaTraitsBase<gfx11_t, fp16_t, fp16_t, float>
|
||||
{
|
||||
template <bool clamp = false>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
|
||||
{
|
||||
#ifdef __gfx11__
|
||||
return __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_vec, b_vec, c_vec);
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
ck_tile::ignore = c_vec;
|
||||
return CVecType{0.f};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
// bf16 specialization - GFX11
|
||||
template <>
|
||||
struct WmmaTraits<gfx11_t, bf16_t, bf16_t, float, 16, 16, 16>
|
||||
: WmmaTraitsBase<gfx11_t, bf16_t, bf16_t, float>
|
||||
{
|
||||
template <bool clamp = false>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
|
||||
{
|
||||
#ifdef __gfx11__
|
||||
return __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(a_vec, b_vec, c_vec);
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
ck_tile::ignore = c_vec;
|
||||
return CVecType{0.f};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
// fp16 specialization - GFX12
|
||||
template <>
|
||||
struct WmmaTraits<gfx12_t, fp16_t, fp16_t, float, 16, 16, 16>
|
||||
: WmmaTraitsBase<gfx12_t, fp16_t, fp16_t, float>
|
||||
{
|
||||
template <bool clamp = false>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
|
||||
{
|
||||
#ifdef __gfx12__
|
||||
return __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_vec, b_vec, c_vec);
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
ck_tile::ignore = c_vec;
|
||||
return CVecType{0.f};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
// bf16 specialization - GFX12
|
||||
template <>
|
||||
struct WmmaTraits<gfx12_t, bf16_t, bf16_t, float, 16, 16, 16>
|
||||
: WmmaTraitsBase<gfx12_t, bf16_t, bf16_t, float>
|
||||
{
|
||||
template <bool clamp = false>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
|
||||
{
|
||||
#ifdef __gfx12__
|
||||
return __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(a_vec, b_vec, c_vec);
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
ck_tile::ignore = c_vec;
|
||||
return CVecType{0.f};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,138 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "warp_gemm_attribute_wmma_impl_base_traits.hpp"
|
||||
namespace ck_tile {
|
||||
// int8 specialization - GFX11
|
||||
template <>
|
||||
struct WmmaTraits<gfx11_t, int8_t, int8_t, int32_t, 16, 16, 16>
|
||||
: WmmaTraitsBase<gfx11_t, int8_t, int8_t, int32_t>
|
||||
{
|
||||
template <bool clamp = false>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
|
||||
{
|
||||
#ifdef __gfx11__
|
||||
return __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(true, // neg_a
|
||||
bit_cast<int32x4_t>(a_vec),
|
||||
true, // neg_b
|
||||
bit_cast<int32x4_t>(b_vec),
|
||||
bit_cast<int32x8_t>(c_vec),
|
||||
clamp);
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
ck_tile::ignore = c_vec;
|
||||
return CVecType{0};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
// int8 specialization - GFX12
|
||||
template <>
|
||||
struct WmmaTraits<gfx12_t, int8_t, int8_t, int32_t, 16, 16, 16>
|
||||
: WmmaTraitsBase<gfx12_t, int8_t, int8_t, int32_t>
|
||||
{
|
||||
template <bool clamp = false>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
|
||||
{
|
||||
#ifdef __gfx12__
|
||||
return __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(true, // neg_a
|
||||
bit_cast<int32x2_t>(a_vec),
|
||||
true, // neg_b
|
||||
bit_cast<int32x2_t>(b_vec),
|
||||
bit_cast<int32x8_t>(c_vec),
|
||||
clamp);
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
ck_tile::ignore = c_vec;
|
||||
return CVecType{0};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
// fp8/bf8 specialization - GFX12
|
||||
template <>
|
||||
struct WmmaTraits<gfx12_t, fp8_t, fp8_t, float, 16, 16, 16>
|
||||
: WmmaTraitsBase<gfx12_t, fp8_t, fp8_t, float>
|
||||
{
|
||||
template <bool clamp = false>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
|
||||
{
|
||||
#ifdef __gfx12__
|
||||
return __builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12(
|
||||
bit_cast<int32x2_t>(a_vec), bit_cast<int32x2_t>(b_vec), bit_cast<fp32x8_t>(c_vec));
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
ck_tile::ignore = c_vec;
|
||||
return CVecType{0};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct WmmaTraits<gfx12_t, bf8_t, bf8_t, float, 16, 16, 16>
|
||||
: WmmaTraitsBase<gfx12_t, bf8_t, bf8_t, float>
|
||||
{
|
||||
template <bool clamp = false>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
|
||||
{
|
||||
#ifdef __gfx12__
|
||||
return __builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12(
|
||||
bit_cast<int32x2_t>(a_vec), bit_cast<int32x2_t>(b_vec), bit_cast<fp32x8_t>(c_vec));
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
ck_tile::ignore = c_vec;
|
||||
return CVecType{0};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct WmmaTraits<gfx12_t, fp8_t, bf8_t, float, 16, 16, 16>
|
||||
: WmmaTraitsBase<gfx12_t, fp8_t, bf8_t, float>
|
||||
{
|
||||
template <bool clamp = false>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
|
||||
{
|
||||
#ifdef __gfx12__
|
||||
return __builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12(
|
||||
bit_cast<int32x2_t>(a_vec), bit_cast<int32x2_t>(b_vec), bit_cast<fp32x8_t>(c_vec));
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
ck_tile::ignore = c_vec;
|
||||
return CVecType{0};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct WmmaTraits<gfx12_t, bf8_t, fp8_t, float, 16, 16, 16>
|
||||
: WmmaTraitsBase<gfx12_t, bf8_t, fp8_t, float>
|
||||
{
|
||||
template <bool clamp = false>
|
||||
CK_TILE_DEVICE static CVecType
|
||||
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
|
||||
{
|
||||
#ifdef __gfx12__
|
||||
return __builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12(
|
||||
bit_cast<int32x2_t>(a_vec), bit_cast<int32x2_t>(b_vec), bit_cast<fp32x8_t>(c_vec));
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
ck_tile::ignore = c_vec;
|
||||
return CVecType{0};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,86 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
namespace ck_tile {
|
||||
template <typename Arch, typename ADType, typename BDType, typename CDType>
|
||||
struct WmmaTraitsBase;
|
||||
|
||||
// GFX11 specialization
|
||||
template <typename ADType, typename BDType, typename CDType>
|
||||
struct WmmaTraitsBase<gfx11_t, ADType, BDType, CDType>
|
||||
{
|
||||
using ADataType = ADType;
|
||||
using BDataType = BDType;
|
||||
using CDataType = CDType;
|
||||
|
||||
using AVecType = ext_vector_t<ADataType, 16>;
|
||||
using BVecType = ext_vector_t<BDataType, 16>;
|
||||
using CVecType = ext_vector_t<CDataType, 8>;
|
||||
|
||||
static constexpr index_t kM = 16;
|
||||
static constexpr index_t kN = 16;
|
||||
static constexpr index_t kK = 16;
|
||||
|
||||
static constexpr index_t kRepeat = 2;
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABK0PerLane = 1;
|
||||
static constexpr index_t kABKLane = 1;
|
||||
static constexpr index_t kABK1PerLane = 16;
|
||||
|
||||
static constexpr index_t kCMLane = 2;
|
||||
static constexpr index_t kCNLane = 16;
|
||||
static constexpr index_t kCM0PerLane = 8;
|
||||
static constexpr index_t kCM1PerLane = 1;
|
||||
|
||||
using kABPs2RHssMajor = sequence<0, 2, 1>;
|
||||
using kABPs2RHssMinor = sequence<0, 1, 0>;
|
||||
using kABYs2RHsMajor = sequence<2, 2>;
|
||||
using kABYs2RHsMinor = sequence<0, 2>;
|
||||
|
||||
using kCPs2RHssMajor = sequence<1, 2>;
|
||||
using kCPs2RHssMinor = sequence<1, 0>;
|
||||
using kCYs2RHsMajor = sequence<1, 1>;
|
||||
using kCYs2RHsMinor = sequence<0, 2>;
|
||||
};
|
||||
|
||||
// GFX12 specialization
|
||||
template <typename ADType, typename BDType, typename CDType>
|
||||
struct WmmaTraitsBase<gfx12_t, ADType, BDType, CDType>
|
||||
{
|
||||
using ADataType = ADType;
|
||||
using BDataType = BDType;
|
||||
using CDataType = CDType;
|
||||
|
||||
using AVecType = ext_vector_t<ADataType, 8>;
|
||||
using BVecType = ext_vector_t<BDataType, 8>;
|
||||
using CVecType = ext_vector_t<CDataType, 8>;
|
||||
|
||||
static constexpr index_t kM = 16;
|
||||
static constexpr index_t kN = 16;
|
||||
static constexpr index_t kK = 16;
|
||||
|
||||
static constexpr index_t kRepeat = 1;
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABK0PerLane = 2;
|
||||
static constexpr index_t kABKLane = 2;
|
||||
static constexpr index_t kABK1PerLane = 4;
|
||||
|
||||
static constexpr index_t kCMLane = 2;
|
||||
static constexpr index_t kCNLane = 16;
|
||||
static constexpr index_t kCM0PerLane = 1;
|
||||
static constexpr index_t kCM1PerLane = 8;
|
||||
|
||||
using kABPs2RHssMajor = sequence<2, 1>;
|
||||
using kABPs2RHssMinor = sequence<1, 0>;
|
||||
using kABYs2RHsMajor = sequence<2, 2>;
|
||||
using kABYs2RHsMinor = sequence<0, 2>;
|
||||
|
||||
using kCPs2RHssMajor = sequence<1, 2>;
|
||||
using kCPs2RHssMinor = sequence<1, 0>;
|
||||
using kCYs2RHsMajor = sequence<1, 1>;
|
||||
using kCYs2RHsMinor = sequence<0, 2>;
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -19,115 +20,133 @@ template <typename AType,
|
||||
bool SwizzleA = false,
|
||||
bool UseStructuredSparsity = false,
|
||||
WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
struct WarpGemmMfmaDispatcher;
|
||||
struct WarpGemmDispatcher;
|
||||
|
||||
// clang-format off
|
||||
// fp16
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaF16F16F32M32N32K8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaF16F16F32M32N32K16<>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution<>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, false, false, WGAttrNumAccessEnum::Double> {
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaF16F16F32M32N32K8; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaF16F16F32M32N32K16<>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution<>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, false, false, WGAttrNumAccessEnum::Double> {
|
||||
using Type = WarpGemmMfmaF16F16F32M32N32K16<WGAttrNumAccessEnum::Double>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, true, false, false, WGAttrNumAccessEnum::Double> {
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, true, false, false, WGAttrNumAccessEnum::Double> {
|
||||
using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution<WGAttrNumAccessEnum::Double>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaF16F16F32M16N16K16; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaF16F16F32M16N16K32<>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false, false, false, WGAttrNumAccessEnum::Double> {
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaF16F16F32M16N16K32<>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false, false, false, WGAttrNumAccessEnum::Double> {
|
||||
using Type = WarpGemmMfmaF16F16F32M16N16K32<WGAttrNumAccessEnum::Double>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true, false, false, WGAttrNumAccessEnum::Double> {
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true, false, false, WGAttrNumAccessEnum::Double> {
|
||||
using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<WGAttrNumAccessEnum::Double>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 4, 64, 16, false> { using Type = WarpGemmMfmaF16F16F32M4N64K16; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 64, 4, 16, false> { using Type = WarpGemmMfmaF16F16F32M64N4K16; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 4, 64, 16, false> { using Type = WarpGemmMfmaF16F16F32M4N64K16; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 64, 4, 16, false> { using Type = WarpGemmMfmaF16F16F32M64N4K16; };
|
||||
// WMMA cases
|
||||
#if defined(__gfx11__) || defined(__gfx12__)
|
||||
template<bool TransposeC> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, TransposeC, false> { using Type = WarpGemmWmma_f32_16x16x16_f16_f16<TransposeC>;};
|
||||
#else
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaF16F16F32M16N16K16; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; };
|
||||
#endif
|
||||
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, true, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleBTransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, true, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, true, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleBTransposedCDistribution; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, true, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution; };
|
||||
|
||||
// fp16 2:4 structural sparsity
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, false, true> { using Type = WarpGemmSmfmacF16F16F32M32N32K16; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false, false, true> { using Type = WarpGemmSmfmacF16F16F32M16N16K32; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, false, true> { using Type = WarpGemmSmfmacF16F16F32M32N32K16; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false, false, true> { using Type = WarpGemmSmfmacF16F16F32M16N16K32; };
|
||||
|
||||
// bf16
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16<>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution<>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, false, false, WGAttrNumAccessEnum::Double> {
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16<>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution<>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, false, false, WGAttrNumAccessEnum::Double> {
|
||||
using Type = WarpGemmMfmaBf16Bf16F32M32N32K16<WGAttrNumAccessEnum::Double>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, true, false, false, WGAttrNumAccessEnum::Double> {
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, true, false, false, WGAttrNumAccessEnum::Double> {
|
||||
using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution<WGAttrNumAccessEnum::Double>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32<>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution<>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false, false, false, WGAttrNumAccessEnum::Double> {
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32<>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution<>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false, false, false, WGAttrNumAccessEnum::Double> {
|
||||
using Type = WarpGemmMfmaBf16Bf16F32M16N16K32<WGAttrNumAccessEnum::Double>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, true, false, false, WGAttrNumAccessEnum::Double> {
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, true, false, false, WGAttrNumAccessEnum::Double> {
|
||||
using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution<WGAttrNumAccessEnum::Double>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 4, 64, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M4N64K16; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 64, 4, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M64N4K16; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 4, 64, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M4N64K16; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 64, 4, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M64N4K16; };
|
||||
// WMMA cases
|
||||
#if defined(__gfx11__) || defined(__gfx12__)
|
||||
template<bool TransposeC> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, TransposeC, false> { using Type = WarpGemmWmma_f32_16x16x16_bf16_bf16<TransposeC>; };
|
||||
#else
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; };
|
||||
#endif
|
||||
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, true, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleBTransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, true, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, true, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleBTransposedCDistribution; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, true, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution; };
|
||||
|
||||
// fp8
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 32, false> { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 64, false> { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_bf8_bf8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 32, false> { using Type = WarpGemmMfma_f32_16x16x32_bf8_bf8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 64, false> { using Type = WarpGemmMfma_f32_16x16x64_bf8_bf8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 32, false> { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 64, false> { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_bf8_bf8; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 32, false> { using Type = WarpGemmMfma_f32_16x16x32_bf8_bf8; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 64, false> { using Type = WarpGemmMfma_f32_16x16x64_bf8_bf8; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; };
|
||||
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 128, false> { using Type = WarpGemmMfma_f32_16x16x128_fp8_fp8<>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 16, 16, 128, false> { using Type = WarpGemmMfma_f32_16x16x128_fp8_bf8<>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 16, 16, 128, false> { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8<>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 128, false> { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8<>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 128, false> { using Type = WarpGemmMfma_f32_16x16x128_fp8_fp8<>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 16, 16, 128, false> { using Type = WarpGemmMfma_f32_16x16x128_fp8_bf8<>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 16, 16, 128, false> { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8<>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 128, false> { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8<>; };
|
||||
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8<>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8<>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8<>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8<>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 64, false, false, false, WGAttrNumAccessEnum::Quad> {
|
||||
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8<>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8<>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8<>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8<>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 64, false, false, false, WGAttrNumAccessEnum::Quad> {
|
||||
using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8<WGAttrNumAccessEnum::Quad>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 64, false, false, false, WGAttrNumAccessEnum::Quad> {
|
||||
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 64, false, false, false, WGAttrNumAccessEnum::Quad> {
|
||||
using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8<WGAttrNumAccessEnum::Quad>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 64, false, false, false, WGAttrNumAccessEnum::Quad> {
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 64, false, false, false, WGAttrNumAccessEnum::Quad> {
|
||||
using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8<WGAttrNumAccessEnum::Quad>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 64, false, false, false, WGAttrNumAccessEnum::Quad> {
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 64, false, false, false, WGAttrNumAccessEnum::Quad> {
|
||||
using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8<WGAttrNumAccessEnum::Quad>; };
|
||||
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 128, false, false, false, WGAttrNumAccessEnum::Quad> {
|
||||
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 128, false, false, false, WGAttrNumAccessEnum::Quad> {
|
||||
using Type = WarpGemmMfma_f32_16x16x128_fp8_fp8<WGAttrNumAccessEnum::Quad>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 16, 16, 128, false, false, false, WGAttrNumAccessEnum::Quad> {
|
||||
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 16, 16, 128, false, false, false, WGAttrNumAccessEnum::Quad> {
|
||||
using Type = WarpGemmMfma_f32_16x16x128_fp8_bf8<WGAttrNumAccessEnum::Quad>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 16, 16, 128, false, false, false, WGAttrNumAccessEnum::Quad> {
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 16, 16, 128, false, false, false, WGAttrNumAccessEnum::Quad> {
|
||||
using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8<WGAttrNumAccessEnum::Quad>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 128, false, false, false, WGAttrNumAccessEnum::Quad> {
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 128, false, false, false, WGAttrNumAccessEnum::Quad> {
|
||||
using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8<WGAttrNumAccessEnum::Quad>; };
|
||||
//WMMA cases
|
||||
template<bool TransposeC> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 16, TransposeC, false> { using Type =WarpGemmWmma_f32_16x16x16_f8_f8<TransposeC>; };
|
||||
template<bool TransposeC> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 16, TransposeC, false> { using Type =WarpGemmWmma_f32_16x16x16_bf8_bf8<TransposeC>; };
|
||||
template<bool TransposeC> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 16, 16, 16, TransposeC, false> { using Type =WarpGemmWmma_f32_16x16x16_f8_bf8<TransposeC>; };
|
||||
template<bool TransposeC> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 16, 16, 16, TransposeC, false> { using Type =WarpGemmWmma_f32_16x16x16_bf8_f8<TransposeC>; };
|
||||
|
||||
// int8
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 32, 32, 16, false> { using Type = WarpGemmMfma_i32_32x32x16_i8_i8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 32, 32, 16, true> { using Type = WarpGemmMfma_i32_32x32x16_i8_i8_CTransposed; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 16, 16, 32, false> { using Type = WarpGemmMfma_i32_16x16x32_i8_i8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 16, 16, 32, true> { using Type = WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 32, 32, 16, false> { using Type = WarpGemmMfma_i32_32x32x16_i8_i8; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 32, 32, 16, true> { using Type = WarpGemmMfma_i32_32x32x16_i8_i8_CTransposed; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 16, 16, 32, false> { using Type = WarpGemmMfma_i32_16x16x32_i8_i8; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 16, 16, 32, true> { using Type = WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed; };
|
||||
// WMMA cases
|
||||
template<bool TransposeC> struct WarpGemmDispatcher<ck_tile::int8_t, ck_tile::int8_t, int32_t, 16, 16, 16, TransposeC, false> { using Type = WarpGemmWmma_i32_16x16x16_i8_i8<TransposeC>;};
|
||||
|
||||
// clang-format on
|
||||
} // namespace impl
|
||||
@@ -142,15 +161,15 @@ template <typename AType,
|
||||
bool SwizzleA = false,
|
||||
bool UseStructuredSparsity = false,
|
||||
WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaDispatcher = typename impl::WarpGemmMfmaDispatcher<AType,
|
||||
BType,
|
||||
AccType,
|
||||
MPerWave,
|
||||
NPerWave,
|
||||
KPerWave,
|
||||
TransposeC,
|
||||
SwizzleA,
|
||||
UseStructuredSparsity,
|
||||
AttrNumAccess>::Type;
|
||||
using WarpGemmDispatcher = typename impl::WarpGemmDispatcher<AType,
|
||||
BType,
|
||||
AccType,
|
||||
MPerWave,
|
||||
NPerWave,
|
||||
KPerWave,
|
||||
TransposeC,
|
||||
SwizzleA,
|
||||
UseStructuredSparsity,
|
||||
AttrNumAccess>::Type;
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
37
include/ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp
Normal file
37
include/ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp
Normal file
@@ -0,0 +1,37 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <bool kTransC = false>
|
||||
using WarpGemmWmma_f32_16x16x16_f16_f16 =
|
||||
WarpGemmImpl<WarpGemmAttributeWmma<WarpGemmAttributeWmmaImpl_f32_16x16x16_f16_f16, kTransC>>;
|
||||
|
||||
template <bool kTransC = false>
|
||||
using WarpGemmWmma_f32_16x16x16_bf16_bf16 =
|
||||
WarpGemmImpl<WarpGemmAttributeWmma<WarpGemmAttributeWmmaImpl_f32_16x16x16_bf16_bf16, kTransC>>;
|
||||
|
||||
template <bool kTransC = false>
|
||||
using WarpGemmWmma_i32_16x16x16_i8_i8 =
|
||||
WarpGemmImpl<WarpGemmAttributeWmma<WarpGemmAttributeWmmaImpl_i32_16x16x16_i8_i8, kTransC>>;
|
||||
|
||||
template <bool kTransC = false>
|
||||
using WarpGemmWmma_f32_16x16x16_f8_f8 =
|
||||
WarpGemmImpl<WarpGemmAttributeWmma<WarpGemmAttributeWmmaImpl_f32_16x16x16_f8_f8, kTransC>>;
|
||||
|
||||
template <bool kTransC = false>
|
||||
using WarpGemmWmma_f32_16x16x16_bf8_bf8 =
|
||||
WarpGemmImpl<WarpGemmAttributeWmma<WarpGemmAttributeWmmaImpl_f32_16x16x16_bf8_bf8, kTransC>>;
|
||||
|
||||
template <bool kTransC = false>
|
||||
using WarpGemmWmma_f32_16x16x16_f8_bf8 =
|
||||
WarpGemmImpl<WarpGemmAttributeWmma<WarpGemmAttributeWmmaImpl_f32_16x16x16_f8_bf8, kTransC>>;
|
||||
|
||||
template <bool kTransC = false>
|
||||
using WarpGemmWmma_f32_16x16x16_bf8_f8 =
|
||||
WarpGemmImpl<WarpGemmAttributeWmma<WarpGemmAttributeWmmaImpl_f32_16x16x16_bf8_f8, kTransC>>;
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -44,13 +44,13 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
constexpr index_t VecLoadSize = GetVectorSizeAQ<Problem>();
|
||||
constexpr bool Preshuffle = Problem::Traits::Preshuffle;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
false>;
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
false>;
|
||||
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
if constexpr(Preshuffle)
|
||||
@@ -92,13 +92,13 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
static_assert(Problem::kQuantGroupSize % WarpTile::at(I2) == 0,
|
||||
"KPerWarpGemm must be a multiple of kQuantGroupSize!");
|
||||
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
false>;
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
false>;
|
||||
static_assert(std::is_same_v<typename Problem::ComputeDataType, fp8_t> ||
|
||||
std::is_same_v<typename Problem::ComputeDataType, bf8_t>);
|
||||
static_assert(std::is_same_v<typename Problem::CDataType, float>);
|
||||
|
||||
@@ -30,6 +30,14 @@ if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_basic_fp8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
add_test_executable(test_ck_tile_gemm_pipeline_basic_bf8 test_gemm_pipeline_basic_bf8.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_basic_bf8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
elseif(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12")
|
||||
# On Radeon devices, build the WMMA version instead
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_mem_wmma test_gemm_pipeline_mem_wmma.cpp)
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_compv3_wmma test_gemm_pipeline_compv3_wmma.cpp)
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_compv4_wmma test_gemm_pipeline_compv4_wmma.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_mem_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_compv3_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_compv4_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS})
|
||||
else()
|
||||
message(DEBUG "Skipping ck_tile_gemm tests for current target")
|
||||
endif()
|
||||
@@ -46,4 +54,7 @@ if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95" OR GPU_TARGETS MAT
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_basic_fp16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
add_test_executable(test_ck_tile_gemm_pipeline_basic_bf16 test_gemm_pipeline_basic_bf16.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_basic_bf16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
elseif(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12")
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_persistent_wmma test_gemm_pipeline_persistent_wmma.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_persistent_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
endif()
|
||||
|
||||
@@ -3,7 +3,8 @@
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
template <typename T>
|
||||
class TestCkTileGemmPipelineCompV3 : public TestCkTileGemmPipeline<T>
|
||||
class TestCkTileGemmPipelineCompV3
|
||||
: public TestCkTileGemmPipeline<T, TestCkTileGemmPipelineCompV3<T>>
|
||||
{
|
||||
};
|
||||
|
||||
|
||||
17
test/ck_tile/gemm/test_gemm_pipeline_compv3_wmma.cpp
Normal file
17
test/ck_tile/gemm/test_gemm_pipeline_compv3_wmma.cpp
Normal file
@@ -0,0 +1,17 @@
|
||||
#include "test_gemm_pipeline_kernel_types.hpp"
|
||||
#include "test_gemm_pipeline_wmma_base.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
template <typename T>
|
||||
class TestCkTileGemmPipelineCompV3Wmma
|
||||
: public TestCkTileGemmPipelineWmmaBase<T, TestCkTileGemmPipelineCompV3Wmma<T>>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileGemmPipelineCompV3Wmma
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileGemmPipelineCompV3Wmma, KernelTypesCompV3Wmma);
|
||||
|
||||
#include "test_gemm_pipeline_ut_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -3,7 +3,8 @@
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
template <typename T>
|
||||
class TestCkTileGemmPipelineCompV4 : public TestCkTileGemmPipeline<T>
|
||||
class TestCkTileGemmPipelineCompV4
|
||||
: public TestCkTileGemmPipeline<T, TestCkTileGemmPipelineCompV4<T>>
|
||||
{
|
||||
};
|
||||
|
||||
|
||||
17
test/ck_tile/gemm/test_gemm_pipeline_compv4_wmma.cpp
Normal file
17
test/ck_tile/gemm/test_gemm_pipeline_compv4_wmma.cpp
Normal file
@@ -0,0 +1,17 @@
|
||||
#include "test_gemm_pipeline_kernel_types.hpp"
|
||||
#include "test_gemm_pipeline_wmma_base.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
template <typename T>
|
||||
class TestCkTileGemmPipelineCompV4Wmma
|
||||
: public TestCkTileGemmPipelineWmmaBase<T, TestCkTileGemmPipelineCompV4Wmma<T>>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileGemmPipelineCompV4Wmma
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileGemmPipelineCompV4Wmma, KernelTypesCompV4Wmma);
|
||||
|
||||
#include "test_gemm_pipeline_ut_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -9,13 +9,16 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "test_gemm_pipeline_util.hpp"
|
||||
|
||||
using I8 = ck_tile::int8_t;
|
||||
using I32 = ck_tile::int32_t;
|
||||
using INT8 = ck_tile::int8_t;
|
||||
using INT32 = ck_tile::int32_t;
|
||||
|
||||
using F16 = ck_tile::half_t;
|
||||
using F32 = float;
|
||||
using F8 = ck_tile::fp8_t;
|
||||
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
|
||||
@@ -30,52 +33,119 @@ using CompV4 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Co
|
||||
using Persistent = std::true_type;
|
||||
using NonPersistent = std::false_type;
|
||||
|
||||
using I16 = ck_tile::number<16>;
|
||||
using I32 = ck_tile::number<32>;
|
||||
using I64 = ck_tile::number<64>;
|
||||
using I256 = ck_tile::number<256>;
|
||||
|
||||
// clang-format off
|
||||
using KernelTypesMem = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, Interwave, Mem>,
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, Intrawave, Mem>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, Interwave, Mem>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, Intrawave, Mem>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, Intrawave, Mem>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, Interwave, Mem>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, Intrawave, Mem>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, Interwave, Mem>
|
||||
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, M_BlockSize, N_BlockSize, K_BlockSize, M_TileSize, M_TileSize, K_TileSize, Scheduler, PipelineType
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, Mem>,
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Interwave, Mem>,
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Interwave, Mem>,
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, Mem>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, Mem>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Interwave, Mem>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Interwave, Mem>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, Mem>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, Mem>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Interwave, Mem>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, Mem>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Interwave, Mem>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, Mem>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Interwave, Mem>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, Mem>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Interwave, Mem>
|
||||
>;
|
||||
|
||||
using KernelTypesMemWmma = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
|
||||
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
|
||||
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
|
||||
std::tuple< Row, Row, Row, INT8, INT8, INT32, INT32, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
|
||||
std::tuple< Row, Row, Row, INT8, INT8, INT32, INT32, I64, I64, I32, I16, I16, I16, Intrawave, Mem>,
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, Mem>,
|
||||
std::tuple< Row, Row, Row, BF8, BF8, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
|
||||
std::tuple< Row, Row, Row, BF8, BF8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, Mem>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, Mem>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, Mem>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, Mem>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, Mem>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, Mem>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, Mem>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Interwave, Mem>
|
||||
>;
|
||||
|
||||
using KernelTypesCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Row, Row, I8, I8, I32, I32, Intrawave, CompV3>,
|
||||
std::tuple< Row, Col, Row, I8, I8, I32, I32, Intrawave, CompV3>,
|
||||
std::tuple< Col, Row, Row, I8, I8, I32, I32, Intrawave, CompV3>,
|
||||
std::tuple< Col, Col, Row, I8, I8, I32, I32, Intrawave, CompV3>
|
||||
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>
|
||||
>;
|
||||
|
||||
using KernelTypesCompV3Wmma = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Row, Row, BF16, BF16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Row, Row, INT8, INT8, INT32, INT32, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Row, Row, BF8, BF8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Col, Row, INT8, INT8, INT32, INT32, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Col, Row, BF8, BF8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Row, Row, BF16, BF16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Row, Row, INT8, INT8, INT32, INT32, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Row, Row, BF8, BF8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Col, Row, BF16, BF16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Col, Row, INT8, INT8, INT32, INT32, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Col, Row, BF8, BF8, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3>
|
||||
>;
|
||||
|
||||
using KernelTypesCompV4 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, CompV4>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV4>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, CompV4>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, CompV4>
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>
|
||||
>;
|
||||
|
||||
using KernelTypesCompV4Wmma = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV4>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV4>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV4>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV4>
|
||||
>;
|
||||
|
||||
|
||||
using KernelTypesPersistent = ::testing::Types<
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV3, Persistent>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV3, NonPersistent>
|
||||
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, M_BlockSize, N_BlockSize, K_BlockSize, M_TileSize, M_TileSize, K_TileSize, Scheduler, PipelineType, Persistent
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3, Persistent>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3, NonPersistent>
|
||||
>;
|
||||
|
||||
using KernelTypesPersistentWmma = ::testing::Types<
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3, Persistent>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3, NonPersistent>
|
||||
>;
|
||||
|
||||
// clang-format on
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
template <typename T>
|
||||
class TestCkTileGemmPipelineMem : public TestCkTileGemmPipeline<T>
|
||||
class TestCkTileGemmPipelineMem : public TestCkTileGemmPipeline<T, TestCkTileGemmPipelineMem<T>>
|
||||
{
|
||||
};
|
||||
|
||||
|
||||
17
test/ck_tile/gemm/test_gemm_pipeline_mem_wmma.cpp
Normal file
17
test/ck_tile/gemm/test_gemm_pipeline_mem_wmma.cpp
Normal file
@@ -0,0 +1,17 @@
|
||||
#include "test_gemm_pipeline_kernel_types.hpp"
|
||||
#include "test_gemm_pipeline_wmma_base.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
template <typename T>
|
||||
class TestCkTileGemmPipelineMemWmma
|
||||
: public TestCkTileGemmPipelineWmmaBase<T, TestCkTileGemmPipelineMemWmma<T>>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileGemmPipelineMemWmma
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileGemmPipelineMemWmma, KernelTypesMemWmma);
|
||||
|
||||
#include "test_gemm_pipeline_ut_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -3,7 +3,8 @@
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
template <typename T>
|
||||
class TestCkTileGemmPipelinePersistent : public TestCkTileGemmPipeline<T>
|
||||
class TestCkTileGemmPipelinePersistent
|
||||
: public TestCkTileGemmPipeline<T, TestCkTileGemmPipelinePersistent<T>>
|
||||
{
|
||||
};
|
||||
|
||||
|
||||
17
test/ck_tile/gemm/test_gemm_pipeline_persistent_wmma.cpp
Normal file
17
test/ck_tile/gemm/test_gemm_pipeline_persistent_wmma.cpp
Normal file
@@ -0,0 +1,17 @@
|
||||
#include "test_gemm_pipeline_kernel_types.hpp"
|
||||
#include "test_gemm_pipeline_wmma_base.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
template <typename T>
|
||||
class TestCkTileGemmPipelinePersistentWmma
|
||||
: public TestCkTileGemmPipelineWmmaBase<T, TestCkTileGemmPipelinePersistentWmma<T>>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileGemmPipelinePersistentWmma
|
||||
|
||||
TYPED_TEST_SUITE(TEST_SUITE_NAME, KernelTypesPersistentWmma);
|
||||
|
||||
#include "test_gemm_pipeline_ut_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -69,7 +69,7 @@ struct GemmPipelineTypeSelector<GemmPipelineType::CompV4, Problem>
|
||||
static constexpr auto GetName() { return "GemmPipelineAgBgCrCompV4"; }
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
template <typename Tuple, typename Derived>
|
||||
class TestCkTileGemmPipeline : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
@@ -80,32 +80,30 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
using BDataType = std::tuple_element_t<4, Tuple>;
|
||||
using AccDataType = std::tuple_element_t<5, Tuple>;
|
||||
using CDataType = std::tuple_element_t<6, Tuple>;
|
||||
static constexpr auto Scheduler = std::tuple_element_t<7, Tuple>::value;
|
||||
static constexpr auto PipelineType = std::tuple_element_t<8, Tuple>::value;
|
||||
static constexpr auto Scheduler = std::tuple_element_t<13, Tuple>::value;
|
||||
static constexpr auto PipelineType = std::tuple_element_t<14, Tuple>::value;
|
||||
|
||||
static constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, Tuple>{};
|
||||
static constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, Tuple>{};
|
||||
static constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, Tuple>{};
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = std::tuple_element_t<10, Tuple>{};
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = std::tuple_element_t<11, Tuple>{};
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = std::tuple_element_t<12, Tuple>{};
|
||||
|
||||
using DsLayout = ck_tile::tuple<>;
|
||||
using DsDataType = ck_tile::tuple<>;
|
||||
|
||||
static constexpr bool Persistent =
|
||||
ck_tile::tuple_element_or_default_t<Tuple, 9, std::false_type>::value;
|
||||
// TODO: expose tile size through test t-param ?
|
||||
ck_tile::tuple_element_or_default_t<Tuple, 15, std::false_type>::value;
|
||||
|
||||
template <bool PadM, bool PadN, bool PadK, bool Preshuffle>
|
||||
void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
// TODO: This should be parameterized in tests
|
||||
constexpr ck_tile::index_t M_Tile = 256;
|
||||
constexpr ck_tile::index_t N_Tile = 256;
|
||||
constexpr ck_tile::index_t K_Tile = (PipelineType == GemmPipelineType::CompV4) ? 32 : 64;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
constexpr bool kPadM = PadM;
|
||||
constexpr bool kPadN = PadN;
|
||||
constexpr bool kPadK = PadK;
|
||||
@@ -247,11 +245,48 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
ck_tile::index_t M_Warp_Tile,
|
||||
ck_tile::index_t N_Warp_Tile,
|
||||
ck_tile::index_t K_Warp_Tile>
|
||||
bool check_data_type()
|
||||
{
|
||||
return static_cast<Derived*>(this)
|
||||
->template check_data_type_impl<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile>();
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
ck_tile::index_t M_Warp_Tile,
|
||||
ck_tile::index_t N_Warp_Tile,
|
||||
ck_tile::index_t K_Warp_Tile>
|
||||
bool check_data_type_impl()
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
public:
|
||||
std::vector<int> k_batches_;
|
||||
|
||||
void SetUp() override
|
||||
{
|
||||
if(!check_data_type<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile>())
|
||||
{
|
||||
GTEST_SKIP() << "Unsupported data type combination for gemm pipeline test.";
|
||||
}
|
||||
if constexpr(PipelineType == GemmPipelineType::CompV4)
|
||||
{
|
||||
// Only do k_batch = 1 when pipeline is CompV4
|
||||
|
||||
24
test/ck_tile/gemm/test_gemm_pipeline_wmma_base.hpp
Normal file
24
test/ck_tile/gemm/test_gemm_pipeline_wmma_base.hpp
Normal file
@@ -0,0 +1,24 @@
|
||||
#pragma once
|
||||
|
||||
#include "test_gemm_pipeline_util.hpp"
|
||||
|
||||
template <typename Tuple, typename Derived>
|
||||
class TestCkTileGemmPipelineWmmaBase : public TestCkTileGemmPipeline<Tuple, Derived>
|
||||
{
|
||||
public:
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
ck_tile::index_t M_Warp_Tile,
|
||||
ck_tile::index_t N_Warp_Tile,
|
||||
ck_tile::index_t K_Warp_Tile>
|
||||
bool check_data_type_impl()
|
||||
{
|
||||
return ck_tile::check_wmma_supported<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile>();
|
||||
}
|
||||
};
|
||||
0
test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_ut_cases.inc
Executable file → Normal file
0
test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_ut_cases.inc
Executable file → Normal file
Reference in New Issue
Block a user