mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Added wrapper and renamed the wmma_v3 instances
This commit is contained in:
@@ -1,9 +1,7 @@
|
||||
add_custom_target(example_gemm_add_relu_xdl)
|
||||
|
||||
add_library(example_gemm_add_relu_xdl_fp16 gemm_add_relu_xdl_fp16.cpp)
|
||||
add_example_executable(example_gemm_add_relu_xdl_fp16 gemm_add_relu_xdl_fp16.cpp)
|
||||
|
||||
add_library(example_gemm_add_relu_xdl_bf16 gemm_add_relu_xdl_bf16.cpp)
|
||||
add_example_executable(example_gemm_add_relu_xdl_bf16 gemm_add_relu_xdl_bf16.cpp)
|
||||
|
||||
|
||||
@@ -12,8 +10,6 @@ add_example_executable(example_gemm_add_relu_wmma_bf16 gemm_add_relu_wmma_bf16.c
|
||||
|
||||
add_example_executable(example_gemm_add_relu_wmma_fp16 gemm_add_relu_wmma_fp16.cpp)
|
||||
|
||||
add_example_executable(example_gemm_add_relu_wmma_v3_fp16 gemm_add_relu_wmma_v3_fp16.cpp)
|
||||
add_example_executable(example_gemm_add_relu_wmma_v3_bf16 gemm_add_relu_wmma_v3_bf16.cpp)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -73,6 +73,6 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_
|
||||
|
||||
// clang-format on
|
||||
|
||||
#include "run_gemm_add_relu_example_v3.inc"
|
||||
#include "run_gemm_add_relu_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_add_relu_example(argc, argv); }
|
||||
@@ -71,6 +71,6 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_
|
||||
|
||||
// clang-format on
|
||||
|
||||
#include "run_gemm_add_relu_example_v3.inc"
|
||||
#include "run_gemm_add_relu_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_add_relu_example(argc, argv); }
|
||||
@@ -45,30 +45,30 @@ void add_device_gemm_add_relu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instan
|
||||
|
||||
#elif defined(CK_USE_WMMA)
|
||||
void add_device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>>>&);
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>>>&);
|
||||
|
||||
void add_device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>>>&);
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>>>&);
|
||||
#endif
|
||||
|
||||
// GEMM + Add + Relu
|
||||
@@ -137,7 +137,7 @@ struct DeviceOperationInstanceFactory<
|
||||
#endif
|
||||
|
||||
#elif defined(CK_USE_WMMA)
|
||||
// For wmma ADataType must be same as BDatatype.
|
||||
|
||||
#if defined(CK_ENABLE_FP16)
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<D0DataType, half_t> && is_same_v<EDataType, half_t>)
|
||||
@@ -151,7 +151,6 @@ struct DeviceOperationInstanceFactory<
|
||||
}
|
||||
#endif
|
||||
|
||||
// For wmma ADataType must be same as BDatatype.
|
||||
#if defined(CK_ENABLE_BF16)
|
||||
if constexpr(is_same_v<ADataType, ck::bhalf_t> && is_same_v<BDataType, ck::bhalf_t> &&
|
||||
is_same_v<D0DataType, ck::bhalf_t> && is_same_v<EDataType, ck::bhalf_t>)
|
||||
|
||||
@@ -4,8 +4,5 @@ add_instance_library(device_gemm_add_relu_instance
|
||||
device_gemm_add_relu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp
|
||||
device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp
|
||||
device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp
|
||||
device_gemm_add_relu_wmma_c_shuffle_v3_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp
|
||||
device_gemm_add_relu_wmma_c_shuffle_v3_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user