From bfdfd1e14855a6661453fa58eb2f2d8f8b519a6f Mon Sep 17 00:00:00 2001 From: lalala-sh Date: Mon, 28 Jul 2025 08:11:15 +0000 Subject: [PATCH] fix reference --- .../run_grouped_flatmm_example.inc | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/example/ck_tile/19_grouped_flatmm/run_grouped_flatmm_example.inc b/example/ck_tile/19_grouped_flatmm/run_grouped_flatmm_example.inc index 955769dab0..b72e687ea0 100644 --- a/example/ck_tile/19_grouped_flatmm/run_grouped_flatmm_example.inc +++ b/example/ck_tile/19_grouped_flatmm/run_grouped_flatmm_example.inc @@ -597,7 +597,7 @@ int run_masked_grouped_flatmm_example_with_layouts( } } - ck_tile::index_t M = 4096;//Ms[0]; + ck_tile::index_t M = 4096; // Ms[0]; ck_tile::index_t N = Ns[0]; ck_tile::index_t K = Ks[0]; @@ -683,12 +683,11 @@ int run_masked_grouped_flatmm_example_with_layouts( BDataType* d_B; CDataType* d_C; ck_tile::hip_check_error(hipMalloc(&d_B, N * K * sizeof(BDataType))); - ck_tile::hip_check_error(hipMalloc(&d_C, M * N * sizeof(CDataType))); - ck_tile::hip_check_error(hipMemset(d_C, 0, M * N * sizeof(CDataType))); + ck_tile::hip_check_error(hipMalloc(&d_C, group_count * M * N * sizeof(CDataType))); + ck_tile::hip_check_error(hipMemset(d_C, 0, group_count * M * N * sizeof(CDataType))); ck_tile::HostTensor c_gpu_ref_host( - ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); - + ck_tile::host_tensor_descriptor(group_count * M, N, stride_C, is_row_major(CLayout{}))); ck_tile::index_t acc_m = 0; for(int i = 0; i < group_count; ++i) { @@ -712,11 +711,12 @@ int run_masked_grouped_flatmm_example_with_layouts( stride_A, stride_B, stride_C); - ck_tile::hip_check_error(hipMemcpy( - c_gpu_ref_host.data(), d_C, m_indices[i] * N * sizeof(CDataType), hipMemcpyDeviceToHost)); + ck_tile::hip_check_error(hipMemcpy(c_gpu_ref_host.data() + i * M * N, + d_C + i * M * N, + M * N * sizeof(CDataType), + hipMemcpyDeviceToHost)); } - ck_tile::hip_check_error(hipFree(d_B)); ck_tile::hip_check_error(hipFree(d_C));