mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Ck tile gemm example (#1488)
* Checkpoint: Finished with the tile example & kernel verification, working on the different matrix layout * Finished the Matrix Layout feature set up. Note: Need to modify the inner block to solve the shuffle problem in the future. * Fix: Clang Format, API fixed from fmha * fix with better naming convention * revert back the pipeline code of fmha * Fixed: Addressed the comments and merge the GEMM shape of GEMM Operator and FMHA Operator to one. * clang format with the reference_gemm file * convert the clang format with the remod.py * Changed the format and variable name of the kernel gemm_shape and partitioner --------- Co-authored-by: thomasning <thomasning@banff-cyxtera-s70-4.ctr.dcgpu>
This commit is contained in:
@@ -4,7 +4,8 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -27,9 +28,9 @@ struct BlockGemmARegBGmemCRegV1
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
// use BlockGemmARegBSmemCRegV1 as the underlying block-GEMM implementation
|
||||
using BlockGemmARegBSmemCRegImpl = BlockGemmARegBSmemCRegV1<
|
||||
using BlockGemmARegBGmemCRegImpl = BlockGemmARegBGmemCRegV1<
|
||||
BlockGemmProblem<ADataType, BDataType, CDataType, kBlockSize, BlockGemmShape>,
|
||||
BlockGemmARegBSmemCRegV1DefaultPolicy>;
|
||||
BlockGemmARegBGmemCRegV1DefaultPolicy>;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize()
|
||||
{
|
||||
@@ -82,7 +83,7 @@ struct BlockGemmARegBGmemCRegV1
|
||||
block_sync_lds();
|
||||
|
||||
// block GEMM
|
||||
BlockGemmARegBSmemCRegImpl{}(c_block_tensor, a_block_tensor, b_block_smem_window);
|
||||
BlockGemmARegBGmemCRegImpl{}(c_block_tensor, a_block_tensor, b_block_smem_window);
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
@@ -128,7 +129,7 @@ struct BlockGemmARegBGmemCRegV1
|
||||
block_sync_lds();
|
||||
|
||||
// block GEMM
|
||||
return BlockGemmARegBSmemCRegImpl{}(a_block_tensor, b_block_smem_window);
|
||||
return BlockGemmARegBGmemCRegImpl{}(a_block_tensor, b_block_smem_window);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -49,6 +49,10 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported data type configuration for GEMM warp execution.");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user