mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 20:21:23 +00:00
Ck tile gemm cshuffle & CK Tile GEMM restructure (#1535)
* ake the cshuffle compilable * modify Mhe reference on gpu and cpu. Correaccess of cshuffle * fix the cpu reference code * Complete the in tile shuffle logic * restructure the kernel template input * change the naming pattern of ck_tile gemm pipeline * Re-format files using remod.py * Solve the fmha conflict with gemm * Comment Addressed from Carlus --------- Co-authored-by: Po Yen, Chen <PoYen.Chen@amd.com>
This commit is contained in:
@@ -27,7 +27,9 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
const BElementOp& b_element_op = {},
|
||||
const ACCElementOp& acc_element_op = {})
|
||||
{
|
||||
const int N = b_n_k.mDesc.get_lengths()[0];
|
||||
const int N = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
|
||||
? b_n_k.mDesc.get_lengths()[0]
|
||||
: b_n_k.mDesc.get_lengths()[1];
|
||||
const int K = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
|
||||
? a_m_k.mDesc.get_lengths()[1]
|
||||
: a_m_k.mDesc.get_lengths()[0];
|
||||
@@ -45,20 +47,31 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
ADataType v_a = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
|
||||
? a_element_op(a_m_k(m, k))
|
||||
: a_element_op(a_m_k(k, m));
|
||||
BDataType v_b = b_element_op(b_n_k(n, k));
|
||||
BDataType v_b = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
|
||||
? b_element_op(b_n_k(n, k))
|
||||
: b_element_op(b_n_k(k, n));
|
||||
|
||||
v_acc += ck_tile::type_convert<AccDataType>(v_a) *
|
||||
ck_tile::type_convert<AccDataType>(v_b);
|
||||
}
|
||||
|
||||
c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
|
||||
CDataType& c_ref = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
|
||||
? c_m_n(m, n)
|
||||
: c_m_n(n, m);
|
||||
c_ref = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
|
||||
}
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f, M)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename LayoutA,
|
||||
typename LayoutB,
|
||||
typename LayoutC>
|
||||
__global__ void naive_gemm_kernel(ADataType* A,
|
||||
BDataType* B,
|
||||
CDataType* C,
|
||||
@@ -76,18 +89,32 @@ __global__ void naive_gemm_kernel(ADataType* A,
|
||||
if(row < M && col < N)
|
||||
{
|
||||
AccDataType acc = 0.0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
acc += static_cast<AccDataType>(A[row * strideA + k]) *
|
||||
static_cast<AccDataType>(B[col * strideB + k]);
|
||||
// Adjust indexing based on matrix layout
|
||||
int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
|
||||
? row * strideA + k
|
||||
: k * strideA + row;
|
||||
int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
|
||||
? col * strideB + k
|
||||
: k * strideB + col;
|
||||
acc += static_cast<AccDataType>(A[a_index]) * static_cast<AccDataType>(B[b_index]);
|
||||
}
|
||||
|
||||
C[row * strideC + col] = acc; // Store as AccDataType
|
||||
int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
|
||||
? row * strideC + col
|
||||
: col * strideC + row;
|
||||
C[c_index] = acc;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename LayoutA,
|
||||
typename LayoutB,
|
||||
typename LayoutC>
|
||||
void reference_gemm_gpu(DeviceMem& a_device,
|
||||
DeviceMem& b_device,
|
||||
DeviceMem& c_device,
|
||||
@@ -145,7 +172,7 @@ void reference_gemm_gpu(DeviceMem& a_device,
|
||||
int numThreadsPerBlock = 256; // Common choice for threads per block
|
||||
int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
|
||||
|
||||
naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType>
|
||||
naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
|
||||
<<<numBlocks, numThreadsPerBlock>>>(d_A, d_B, d_C, M, N, K, stride_a, stride_b, stride_c);
|
||||
errC = hipMemcpy(
|
||||
c_device.GetDeviceBuffer(), d_C, M * N * sizeof(CDataType), hipMemcpyDeviceToHost);
|
||||
|
||||
Reference in New Issue
Block a user