Ck tile gemm example (#1488)

* Checkpoint: Finished with the tile example & kernel verification, working on the different matrix layout

* Finished the Matrix Layout feature set up. Note: Need to modify the inner block to solve the shuffle problem in the future.

* Fix: Clang Format, API fixed from fmha

* fix with better naming convention

* revert back the pipeline code of fmha

* Fixed: Addressed the comments and merge the GEMM shape of GEMM Operator and FMHA Operator to one.

* clang format with the reference_gemm file

* convert the clang format with the remod.py

* Changed the format and variable name of the kernel gemm_shape and partitioner

---------

Co-authored-by: thomasning <thomasning@banff-cyxtera-s70-4.ctr.dcgpu>
This commit is contained in:
Thomas Ning
2024-09-07 01:23:32 -07:00
committed by GitHub
parent 8378855361
commit caacd38830
18 changed files with 758 additions and 92 deletions

View File

@@ -13,6 +13,9 @@ template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename LayoutA,
typename LayoutB,
typename LayoutC,
typename AElementOp = ck_tile::identity,
typename BElementOp = ck_tile::identity,
typename ACCElementOp = ck_tile::identity>
@@ -24,7 +27,12 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
const ACCElementOp& acc_element_op = {})
{
const int N = b_n_k.mDesc.get_lengths()[0];
const int K = b_n_k.mDesc.get_lengths()[1];
const int K = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? a_m_k.mDesc.get_lengths()[1]
: a_m_k.mDesc.get_lengths()[0];
const int M = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? a_m_k.mDesc.get_lengths()[0]
: a_m_k.mDesc.get_lengths()[1];
auto f = [&](auto m) {
for(int n = 0; n < N; ++n)
@@ -33,7 +41,9 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
for(int k = 0; k < K; ++k)
{
ADataType v_a = a_element_op(a_m_k(m, k));
ADataType v_a = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? a_element_op(a_m_k(m, k))
: a_element_op(a_m_k(k, m));
BDataType v_b = b_element_op(b_n_k(n, k));
v_acc += ck_tile::type_convert<AccDataType>(v_a) *
@@ -44,7 +54,6 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
}
};
make_ParallelTensorFunctor(f,
c_m_n.mDesc.get_lengths()[0])(std::thread::hardware_concurrency());
make_ParallelTensorFunctor(f, M)(std::thread::hardware_concurrency());
}
} // namespace ck_tile