Added wrapper and renamed the wmma_v3 instances

This commit is contained in:
apoorva
2025-07-08 11:26:01 +00:00
parent 86ca6b827d
commit 9b64da2298
8 changed files with 25 additions and 33 deletions

View File

@@ -1,9 +1,7 @@
add_custom_target(example_gemm_add_relu_xdl)
add_library(example_gemm_add_relu_xdl_fp16 gemm_add_relu_xdl_fp16.cpp)
add_example_executable(example_gemm_add_relu_xdl_fp16 gemm_add_relu_xdl_fp16.cpp)
add_library(example_gemm_add_relu_xdl_bf16 gemm_add_relu_xdl_bf16.cpp)
add_example_executable(example_gemm_add_relu_xdl_bf16 gemm_add_relu_xdl_bf16.cpp)
@@ -12,8 +10,6 @@ add_example_executable(example_gemm_add_relu_wmma_bf16 gemm_add_relu_wmma_bf16.c
add_example_executable(example_gemm_add_relu_wmma_fp16 gemm_add_relu_wmma_fp16.cpp)
add_example_executable(example_gemm_add_relu_wmma_v3_fp16 gemm_add_relu_wmma_v3_fp16.cpp)
add_example_executable(example_gemm_add_relu_wmma_v3_bf16 gemm_add_relu_wmma_v3_bf16.cpp)

View File

@@ -73,6 +73,6 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_
// clang-format on
#include "run_gemm_add_relu_example_v3.inc"
#include "run_gemm_add_relu_example.inc"
int main(int argc, char* argv[]) { return !run_gemm_add_relu_example(argc, argv); }

View File

@@ -71,6 +71,6 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_
// clang-format on
#include "run_gemm_add_relu_example_v3.inc"
#include "run_gemm_add_relu_example.inc"
int main(int argc, char* argv[]) { return !run_gemm_add_relu_example(argc, argv); }

View File

@@ -45,30 +45,30 @@ void add_device_gemm_add_relu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instan
#elif defined(CK_USE_WMMA)
void add_device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Row,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
AddRelu>>>&);
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Row,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
AddRelu>>>&);
void add_device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Row,
Row_Tuple,
Row,
BF16,
BF16,
BF16_Tuple,
BF16,
PassThrough,
PassThrough,
AddRelu>>>&);
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Row,
Row_Tuple,
Row,
BF16,
BF16,
BF16_Tuple,
BF16,
PassThrough,
PassThrough,
AddRelu>>>&);
#endif
// GEMM + Add + Relu
@@ -137,7 +137,7 @@ struct DeviceOperationInstanceFactory<
#endif
#elif defined(CK_USE_WMMA)
// For WMMA, ADataType must be the same as BDataType.
#if defined(CK_ENABLE_FP16)
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<D0DataType, half_t> && is_same_v<EDataType, half_t>)
@@ -151,7 +151,6 @@ struct DeviceOperationInstanceFactory<
}
#endif
// For WMMA, ADataType must be the same as BDataType.
#if defined(CK_ENABLE_BF16)
if constexpr(is_same_v<ADataType, ck::bhalf_t> && is_same_v<BDataType, ck::bhalf_t> &&
is_same_v<D0DataType, ck::bhalf_t> && is_same_v<EDataType, ck::bhalf_t>)

View File

@@ -4,8 +4,5 @@ add_instance_library(device_gemm_add_relu_instance
device_gemm_add_relu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp
device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp
device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp
device_gemm_add_relu_wmma_c_shuffle_v3_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp
device_gemm_add_relu_wmma_c_shuffle_v3_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp
)