Fix reference: size the reference output buffers for all groups and copy each group's result at its own offset

This commit is contained in:
lalala-sh
2025-07-28 08:11:15 +00:00
parent c585cc1429
commit bfdfd1e148

View File

@@ -597,7 +597,7 @@ int run_masked_grouped_flatmm_example_with_layouts(
}
}
ck_tile::index_t M = 4096;//Ms[0];
ck_tile::index_t M = 4096; // Ms[0];
ck_tile::index_t N = Ns[0];
ck_tile::index_t K = Ks[0];
@@ -683,12 +683,11 @@ int run_masked_grouped_flatmm_example_with_layouts(
BDataType* d_B;
CDataType* d_C;
ck_tile::hip_check_error(hipMalloc(&d_B, N * K * sizeof(BDataType)));
ck_tile::hip_check_error(hipMalloc(&d_C, M * N * sizeof(CDataType)));
ck_tile::hip_check_error(hipMemset(d_C, 0, M * N * sizeof(CDataType)));
ck_tile::hip_check_error(hipMalloc(&d_C, group_count * M * N * sizeof(CDataType)));
ck_tile::hip_check_error(hipMemset(d_C, 0, group_count * M * N * sizeof(CDataType)));
ck_tile::HostTensor<CDataType> c_gpu_ref_host(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
ck_tile::host_tensor_descriptor(group_count * M, N, stride_C, is_row_major(CLayout{})));
ck_tile::index_t acc_m = 0;
for(int i = 0; i < group_count; ++i)
{
@@ -712,11 +711,12 @@ int run_masked_grouped_flatmm_example_with_layouts(
stride_A,
stride_B,
stride_C);
ck_tile::hip_check_error(hipMemcpy(
c_gpu_ref_host.data(), d_C, m_indices[i] * N * sizeof(CDataType), hipMemcpyDeviceToHost));
ck_tile::hip_check_error(hipMemcpy(c_gpu_ref_host.data() + i * M * N,
d_C + i * M * N,
M * N * sizeof(CDataType),
hipMemcpyDeviceToHost));
}
ck_tile::hip_check_error(hipFree(d_B));
ck_tile::hip_check_error(hipFree(d_C));