External Interface (#304)

* add client example

* clean

* clean

* reorg

* clean up profiler

* reorg

* clea

* fix profiler

* function for getinstances

* update client example

* update client example

* update client example

* update

* update example

* update Jenkins file

* update cmake

* update Jenkins

[ROCm/composable_kernel commit: aebd211c36]
This commit is contained in:
Chao Liu
2022-06-26 19:39:02 -05:00
committed by GitHub
parent 73af96b913
commit 9096feca63
73 changed files with 2199 additions and 1961 deletions

View File

@@ -84,8 +84,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
8>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<float, float, float, float, PassThrough, PassThrough, PassThrough>;
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough>;
int main(int argc, char* argv[])
{
@@ -216,24 +221,17 @@ int main(int argc, char* argv[])
if(do_verification)
{
Tensor<float> a_f32_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<float> b_f32_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<float> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<float> c_m_n_device_f32_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
bf16_to_f32_(a_m_k, a_f32_m_k);
bf16_to_f32_(b_k_n, b_f32_k_n);
bf16_to_f32_(c_m_n_device_result, c_m_n_device_f32_result);
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_f32_m_k, b_f32_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
return ck::utils::check_err(c_m_n_device_f32_result.mData, c_m_n_host_result.mData) ? 0 : 1;
return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1;
}
return 0;