mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Support for dtypes (fp8, bf8, bf16 and fp16) for the ck_tile/03_gemm example. (#1845)
* Support bf16/fb8/bf8 datatypes for ck_tile/gemm * remove commented out code. * Addressing code review comments and enabling universal_gemm for all the supported data types. * Merge conflict resolution. * Solve the memory pipeline compilation error. Merge with the new change of CShuffle * finish the feature, pass the tests * Fix the pipeline and add the benchmark script for other data types --------- Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
@@ -80,13 +80,14 @@ __global__ void naive_gemm_kernel(ADataType* A,
|
||||
int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
|
||||
? col * strideB + k
|
||||
: k * strideB + col;
|
||||
acc += static_cast<AccDataType>(A[a_index]) * static_cast<AccDataType>(B[b_index]);
|
||||
acc += ck_tile::type_convert<AccDataType>(A[a_index]) *
|
||||
ck_tile::type_convert<AccDataType>(B[b_index]);
|
||||
}
|
||||
|
||||
int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
|
||||
? row * strideC + col
|
||||
: col * strideC + row;
|
||||
C[c_index] = acc;
|
||||
C[c_index] = ck_tile::type_convert<CDataType>(acc);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user