mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Improve external interface for GEMM and GEMM+add+add+fastgelu (#311)
* interface for GEMM and GEMM+add+add+fastgelu * rename namespace * instance factory * fix build * fix build; add GEMM client example * clean
This commit is contained in:
@@ -1,2 +1,2 @@
|
||||
add_executable(gemm_add_add_reduce_normalize gemm_add_add_layernorm.cpp)
|
||||
target_link_libraries(gemm_add_add_reduce_normalize PRIVATE composable_kernel::device_operations)
|
||||
add_executable(client_gemm_add_add_reduce_normalize gemm_add_add_layernorm.cpp)
|
||||
target_link_libraries(client_gemm_add_add_reduce_normalize PRIVATE composable_kernel::device_operations)
|
||||
|
||||
@@ -160,16 +160,17 @@ int main()
|
||||
ck::index_t StrideC = 1024;
|
||||
ck::index_t StrideD0 = 1024;
|
||||
|
||||
const auto gemm_reduce_ptrs = ck::tensor_operation::device::device_gemm_instance::
|
||||
get_device_gemm_add_add_mean_squaremean_instances<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>();
|
||||
const auto gemm_reduce_ptrs =
|
||||
ck::tensor_operation::device::instance::get_device_gemm_add_add_mean_squaremean_instances<
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>();
|
||||
|
||||
const auto normalize_ptrs =
|
||||
ck::tensor_operation::device::get_device_normalize_from_mean_meansquare_instances<
|
||||
ck::tensor_operation::device::instance::get_device_normalize_from_mean_meansquare_instances<
|
||||
CDataType,
|
||||
ReduceDataType,
|
||||
ReduceDataType,
|
||||
@@ -267,4 +268,4 @@ int main()
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user