diff --git a/example/ck_tile/18_multi_d_gemm/run_multi_d_gemm_example.inc b/example/ck_tile/18_multi_d_gemm/run_multi_d_gemm_example.inc index 4868f34ce6..4511a0b895 100644 --- a/example/ck_tile/18_multi_d_gemm/run_multi_d_gemm_example.inc +++ b/example/ck_tile/18_multi_d_gemm/run_multi_d_gemm_example.inc @@ -16,21 +16,21 @@ template float invoke_multi_d_gemm(const void* a_m_k_dev_buf, const void* b_k_n_dev_buf, - const std::array& ds_m_n_dev_buf, + const void* ds_m_n_dev_buf, void* c_m_n_dev_buf, ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K, ck_tile::index_t StrideA, ck_tile::index_t StrideB, - const std::array& StrideDs, + const ck_tile::index_t* StrideDs, ck_tile::index_t StrideC, int n_warmup, int n_repeat) { multiple_d_gemm_kargs gemm_descs({a_m_k_dev_buf, b_k_n_dev_buf, - ds_m_n_dev_buf.data(), + ds_m_n_dev_buf, c_m_n_dev_buf, /*kbatch */ 1, M, @@ -38,7 +38,7 @@ float invoke_multi_d_gemm(const void* a_m_k_dev_buf, K, StrideA, StrideB, - StrideDs.data(), + StrideDs, StrideC}); float ave_time = multiple_d_gemm(a_m_k_dev_buf.GetDeviceBuffer(), b_k_n_dev_buf.GetDeviceBuffer(), - ds_ptr_buf, + ds_ptr_buf.data(), c_m_n_dev_buf.GetDeviceBuffer(), M, N, K, StrideA, StrideB, - stridesDs, + stridesDs.data(), StrideC, n_warmup, n_repeat); diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index 9b118783f6..a817f191de 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -90,7 +90,6 @@ struct GemmKernel // Below type is actually accumulation data type - the output of block GEMM. using CDataType = remove_cvref_t; - using Empty_Tuple = ck_tile::tuple<>; static constexpr index_t NumDTensor = DsDataType::size(); static constexpr auto I0 = number<0>();