diff --git a/example/ck_tile/19_grouped_flatmm/run_grouped_flatmm_example.inc b/example/ck_tile/19_grouped_flatmm/run_grouped_flatmm_example.inc index 955769dab0..b72e687ea0 100644 --- a/example/ck_tile/19_grouped_flatmm/run_grouped_flatmm_example.inc +++ b/example/ck_tile/19_grouped_flatmm/run_grouped_flatmm_example.inc @@ -597,7 +597,7 @@ int run_masked_grouped_flatmm_example_with_layouts( } } - ck_tile::index_t M = 4096;//Ms[0]; + ck_tile::index_t M = 4096; // Ms[0]; ck_tile::index_t N = Ns[0]; ck_tile::index_t K = Ks[0]; @@ -683,12 +683,11 @@ int run_masked_grouped_flatmm_example_with_layouts( BDataType* d_B; CDataType* d_C; ck_tile::hip_check_error(hipMalloc(&d_B, N * K * sizeof(BDataType))); - ck_tile::hip_check_error(hipMalloc(&d_C, M * N * sizeof(CDataType))); - ck_tile::hip_check_error(hipMemset(d_C, 0, M * N * sizeof(CDataType))); + ck_tile::hip_check_error(hipMalloc(&d_C, group_count * M * N * sizeof(CDataType))); + ck_tile::hip_check_error(hipMemset(d_C, 0, group_count * M * N * sizeof(CDataType))); ck_tile::HostTensor c_gpu_ref_host( - ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); - + ck_tile::host_tensor_descriptor(group_count * M, N, stride_C, is_row_major(CLayout{}))); ck_tile::index_t acc_m = 0; for(int i = 0; i < group_count; ++i) { @@ -712,11 +711,12 @@ int run_masked_grouped_flatmm_example_with_layouts( stride_A, stride_B, stride_C); - ck_tile::hip_check_error(hipMemcpy( - c_gpu_ref_host.data(), d_C, m_indices[i] * N * sizeof(CDataType), hipMemcpyDeviceToHost)); + ck_tile::hip_check_error(hipMemcpy(c_gpu_ref_host.data() + i * M * N, + d_C + i * M * N, + M * N * sizeof(CDataType), + hipMemcpyDeviceToHost)); } - ck_tile::hip_check_error(hipFree(d_B)); ck_tile::hip_check_error(hipFree(d_C));