mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
[CK_TILE] CK_TILE GEMM WMMA Support for GFX11/GFX12 (#2466)
* WMMA GEMM F16 Implementation Signed-off-by: root <tianyuwu@amd.com> * Self-review Signed-off-by: root <tianyuwu@amd.com> * ASIC check minor tweak Signed-off-by: root <tianyuwu@amd.com> * add missing include file * Set GPU_TARGETS to gfx11/12 generic Signed-off-by: root <tianyuwu@amd.com> * INT8 GFX12 Signed-off-by: root <tianyuwu@amd.com> * add int8x16 branch * Fix CI script Signed-off-by: root <tianyuwu@amd.com> * Fix typo Signed-off-by: root <tianyuwu@amd.com> * Add CK_Tile WMMA example Signed-off-by: Tianyuan Wu <tianyuwu@amd.com> * Fix CI Signed-off-by: Tianyuan Wu <tianyuwu@amd.com> * fix clang format * Set M/N_Warp Back to Constant Signed-off-by: Tianyuan Wu <tianyuwu@amd.com> * Use GemmConfigComputeV3 by default Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com> * Enable CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT for gfx12 Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com> * Remove CK_Tile wmma gemm examples from the CI list Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com> * Add atomic add fallback method for gfx11 Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com> * Fix typo Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com> * Omit copyright year Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com> * Support non-square cases Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com> * Fix CI Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com> * Add get_device_ip() Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com> * Revert "Add atomic add fallback method for gfx11" This reverts commit07a79e797d. Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com> * Revert "Enable CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT for gfx12" This reverts commitceee918007. * Revise method name and typos Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com> * clang-format Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com> * Try fix CI Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com> * Revert "Try fix CI" This reverts commit7a7241085e. * clang-format Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com> * Fix typo caused by merge Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com> * Fix typo caused by merging Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com> --------- Signed-off-by: root <tianyuwu@amd.com> Signed-off-by: Tianyuan Wu <tianyuwu@amd.com> Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com> Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com> Co-authored-by: joye <joye@amd.com> Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>
This commit is contained in:
@@ -69,7 +69,7 @@ struct GemmPipelineTypeSelector<GemmPipelineType::CompV4, Problem>
|
||||
static constexpr auto GetName() { return "GemmPipelineAgBgCrCompV4"; }
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
template <typename Tuple, typename Derived>
|
||||
class TestCkTileGemmPipeline : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
@@ -80,32 +80,30 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
using BDataType = std::tuple_element_t<4, Tuple>;
|
||||
using AccDataType = std::tuple_element_t<5, Tuple>;
|
||||
using CDataType = std::tuple_element_t<6, Tuple>;
|
||||
static constexpr auto Scheduler = std::tuple_element_t<7, Tuple>::value;
|
||||
static constexpr auto PipelineType = std::tuple_element_t<8, Tuple>::value;
|
||||
static constexpr auto Scheduler = std::tuple_element_t<13, Tuple>::value;
|
||||
static constexpr auto PipelineType = std::tuple_element_t<14, Tuple>::value;
|
||||
|
||||
static constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, Tuple>{};
|
||||
static constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, Tuple>{};
|
||||
static constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, Tuple>{};
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = std::tuple_element_t<10, Tuple>{};
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = std::tuple_element_t<11, Tuple>{};
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = std::tuple_element_t<12, Tuple>{};
|
||||
|
||||
using DsLayout = ck_tile::tuple<>;
|
||||
using DsDataType = ck_tile::tuple<>;
|
||||
|
||||
static constexpr bool Persistent =
|
||||
ck_tile::tuple_element_or_default_t<Tuple, 9, std::false_type>::value;
|
||||
// TODO: expose tile size through test t-param ?
|
||||
ck_tile::tuple_element_or_default_t<Tuple, 15, std::false_type>::value;
|
||||
|
||||
template <bool PadM, bool PadN, bool PadK, bool Preshuffle>
|
||||
void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
// TODO: This should be parameterized in tests
|
||||
constexpr ck_tile::index_t M_Tile = 256;
|
||||
constexpr ck_tile::index_t N_Tile = 256;
|
||||
constexpr ck_tile::index_t K_Tile = (PipelineType == GemmPipelineType::CompV4) ? 32 : 64;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
constexpr bool kPadM = PadM;
|
||||
constexpr bool kPadN = PadN;
|
||||
constexpr bool kPadK = PadK;
|
||||
@@ -247,11 +245,48 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
ck_tile::index_t M_Warp_Tile,
|
||||
ck_tile::index_t N_Warp_Tile,
|
||||
ck_tile::index_t K_Warp_Tile>
|
||||
bool check_data_type()
|
||||
{
|
||||
return static_cast<Derived*>(this)
|
||||
->template check_data_type_impl<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile>();
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
ck_tile::index_t M_Warp_Tile,
|
||||
ck_tile::index_t N_Warp_Tile,
|
||||
ck_tile::index_t K_Warp_Tile>
|
||||
bool check_data_type_impl()
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
public:
|
||||
std::vector<int> k_batches_;
|
||||
|
||||
void SetUp() override
|
||||
{
|
||||
if(!check_data_type<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile>())
|
||||
{
|
||||
GTEST_SKIP() << "Unsupported data type combination for gemm pipeline test.";
|
||||
}
|
||||
if constexpr(PipelineType == GemmPipelineType::CompV4)
|
||||
{
|
||||
// Only do k_batch = 1 when pipeline is CompV4
|
||||
|
||||
Reference in New Issue
Block a user