mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-03 05:37:34 +00:00
Remove explicit type in invoke
This commit is contained in:
@@ -16,21 +16,21 @@ template <typename ADataType,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float invoke_multi_d_gemm(const void* a_m_k_dev_buf,
|
||||
const void* b_k_n_dev_buf,
|
||||
const std::array<const void*, DsDataType::size()>& ds_m_n_dev_buf,
|
||||
const void* ds_m_n_dev_buf,
|
||||
void* c_m_n_dev_buf,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t StrideA,
|
||||
ck_tile::index_t StrideB,
|
||||
const std::array<ck_tile::index_t, DsDataType::size()>& StrideDs,
|
||||
const ck_tile::index_t* StrideDs,
|
||||
ck_tile::index_t StrideC,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
{
|
||||
multiple_d_gemm_kargs gemm_descs({a_m_k_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
ds_m_n_dev_buf.data(),
|
||||
ds_m_n_dev_buf,
|
||||
c_m_n_dev_buf,
|
||||
/*kbatch */ 1,
|
||||
M,
|
||||
@@ -38,7 +38,7 @@ float invoke_multi_d_gemm(const void* a_m_k_dev_buf,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs.data(),
|
||||
StrideDs,
|
||||
StrideC});
|
||||
|
||||
float ave_time = multiple_d_gemm<ADataType,
|
||||
@@ -165,14 +165,14 @@ int run_multiple_d_gemm_example_with_layouts(int argc,
|
||||
CLayout,
|
||||
CDEElementWiseFn>(a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
ds_ptr_buf,
|
||||
ds_ptr_buf.data(),
|
||||
c_m_n_dev_buf.GetDeviceBuffer(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
stridesDs,
|
||||
stridesDs.data(),
|
||||
StrideC,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
|
||||
@@ -90,7 +90,6 @@ struct GemmKernel
|
||||
// Below type is actually accumulation data type - the output of block GEMM.
|
||||
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
|
||||
using Empty_Tuple = ck_tile::tuple<>;
|
||||
static constexpr index_t NumDTensor = DsDataType::size();
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
|
||||
Reference in New Issue
Block a user