mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
fix reference
This commit is contained in:
@@ -597,7 +597,7 @@ int run_masked_grouped_flatmm_example_with_layouts(
|
||||
}
|
||||
}
|
||||
|
||||
ck_tile::index_t M = 4096;//Ms[0];
|
||||
ck_tile::index_t M = 4096; // Ms[0];
|
||||
ck_tile::index_t N = Ns[0];
|
||||
ck_tile::index_t K = Ks[0];
|
||||
|
||||
@@ -683,12 +683,11 @@ int run_masked_grouped_flatmm_example_with_layouts(
|
||||
BDataType* d_B;
|
||||
CDataType* d_C;
|
||||
ck_tile::hip_check_error(hipMalloc(&d_B, N * K * sizeof(BDataType)));
|
||||
ck_tile::hip_check_error(hipMalloc(&d_C, M * N * sizeof(CDataType)));
|
||||
ck_tile::hip_check_error(hipMemset(d_C, 0, M * N * sizeof(CDataType)));
|
||||
ck_tile::hip_check_error(hipMalloc(&d_C, group_count * M * N * sizeof(CDataType)));
|
||||
ck_tile::hip_check_error(hipMemset(d_C, 0, group_count * M * N * sizeof(CDataType)));
|
||||
|
||||
ck_tile::HostTensor<CDataType> c_gpu_ref_host(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
|
||||
ck_tile::host_tensor_descriptor(group_count * M, N, stride_C, is_row_major(CLayout{})));
|
||||
ck_tile::index_t acc_m = 0;
|
||||
for(int i = 0; i < group_count; ++i)
|
||||
{
|
||||
@@ -712,11 +711,12 @@ int run_masked_grouped_flatmm_example_with_layouts(
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C);
|
||||
ck_tile::hip_check_error(hipMemcpy(
|
||||
c_gpu_ref_host.data(), d_C, m_indices[i] * N * sizeof(CDataType), hipMemcpyDeviceToHost));
|
||||
ck_tile::hip_check_error(hipMemcpy(c_gpu_ref_host.data() + i * M * N,
|
||||
d_C + i * M * N,
|
||||
M * N * sizeof(CDataType),
|
||||
hipMemcpyDeviceToHost));
|
||||
}
|
||||
|
||||
|
||||
ck_tile::hip_check_error(hipFree(d_B));
|
||||
ck_tile::hip_check_error(hipFree(d_C));
|
||||
|
||||
|
||||
Reference in New Issue
Block a user