Remove explicit type in invoke

This commit is contained in:
Mateusz Ozga
2025-03-22 18:46:54 +00:00
parent 8ce06348b5
commit bbcc011b36
2 changed files with 6 additions and 7 deletions

View File

@@ -16,21 +16,21 @@ template <typename ADataType,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_multi_d_gemm(const void* a_m_k_dev_buf,
const void* b_k_n_dev_buf,
const std::array<const void*, DsDataType::size()>& ds_m_n_dev_buf,
const void* ds_m_n_dev_buf,
void* c_m_n_dev_buf,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t StrideA,
ck_tile::index_t StrideB,
const std::array<ck_tile::index_t, DsDataType::size()>& StrideDs,
const ck_tile::index_t* StrideDs,
ck_tile::index_t StrideC,
int n_warmup,
int n_repeat)
{
multiple_d_gemm_kargs gemm_descs({a_m_k_dev_buf,
b_k_n_dev_buf,
ds_m_n_dev_buf.data(),
ds_m_n_dev_buf,
c_m_n_dev_buf,
/*kbatch */ 1,
M,
@@ -38,7 +38,7 @@ float invoke_multi_d_gemm(const void* a_m_k_dev_buf,
K,
StrideA,
StrideB,
StrideDs.data(),
StrideDs,
StrideC});
float ave_time = multiple_d_gemm<ADataType,
@@ -165,14 +165,14 @@ int run_multiple_d_gemm_example_with_layouts(int argc,
CLayout,
CDEElementWiseFn>(a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
ds_ptr_buf,
ds_ptr_buf.data(),
c_m_n_dev_buf.GetDeviceBuffer(),
M,
N,
K,
StrideA,
StrideB,
stridesDs,
stridesDs.data(),
StrideC,
n_warmup,
n_repeat);

View File

@@ -90,7 +90,6 @@ struct GemmKernel
// Below type is actually accumulation data type - the output of block GEMM.
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
using Empty_Tuple = ck_tile::tuple<>;
static constexpr index_t NumDTensor = DsDataType::size();
static constexpr auto I0 = number<0>();