mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 21:09:08 +00:00
Merge commit 'a44bea45b205a84552e417a7b069d962d73c6cb1' into develop
This commit is contained in:
@@ -88,10 +88,10 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs;
|
||||
using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs<>;
|
||||
inline std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
|
||||
{
|
||||
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg);
|
||||
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg<>);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@@ -333,8 +333,18 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
|
||||
const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer();
|
||||
void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer();
|
||||
|
||||
gemm_descs.push_back(
|
||||
{p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]});
|
||||
gemm_descs.push_back({p_a,
|
||||
p_b,
|
||||
{/*ds_ptr*/},
|
||||
p_c,
|
||||
kbatch,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_As[i],
|
||||
stride_Bs[i],
|
||||
{/*stride_Ds*/},
|
||||
stride_Cs[i]});
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem gemm_workspace;
|
||||
|
||||
Reference in New Issue
Block a user