diff --git a/client_example/10_grouped_conv2d_bwd_data/CMakeLists.txt b/client_example/10_grouped_conv2d_bwd_data/CMakeLists.txt deleted file mode 100644 index e564f3180d..0000000000 --- a/client_example/10_grouped_conv2d_bwd_data/CMakeLists.txt +++ /dev/null @@ -1,2 +0,0 @@ -add_executable(client_grouped_conv2d_bwd_data grouped_conv2d_bwd_data.cpp) -target_link_libraries(client_grouped_conv2d_bwd_data PRIVATE composable_kernel::device_operations) diff --git a/client_example/10_grouped_convnd_bwd_data/CMakeLists.txt b/client_example/10_grouped_convnd_bwd_data/CMakeLists.txt new file mode 100644 index 0000000000..60543c7308 --- /dev/null +++ b/client_example/10_grouped_convnd_bwd_data/CMakeLists.txt @@ -0,0 +1,8 @@ +add_executable(client_grouped_conv2d_bwd_data grouped_conv2d_bwd_data.cpp) +target_link_libraries(client_grouped_conv2d_bwd_data PRIVATE composable_kernel::device_operations) + +add_executable(client_grouped_conv3d_bwd_data grouped_conv3d_bwd_data.cpp) +target_link_libraries(client_grouped_conv3d_bwd_data PRIVATE composable_kernel::device_operations) + +add_executable(client_grouped_conv3d_bwd_data_input_fp16_comp_bf8f8 grouped_conv3d_bwd_data_input_fp16_comp_bf8f8.cpp) +target_link_libraries(client_grouped_conv3d_bwd_data_input_fp16_comp_bf8f8 PRIVATE composable_kernel::device_operations) diff --git a/client_example/10_grouped_conv2d_bwd_data/grouped_conv2d_bwd_data.cpp b/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data.cpp similarity index 100% rename from client_example/10_grouped_conv2d_bwd_data/grouped_conv2d_bwd_data.cpp rename to client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data.cpp diff --git a/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data.cpp b/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data.cpp new file mode 100644 index 0000000000..d2f2ff41bc --- /dev/null +++ b/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data.cpp @@ -0,0 +1,205 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 2; +static constexpr ck::index_t N = 16; +static constexpr ck::index_t K = 16; +static constexpr ck::index_t C = 16; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 14; +static constexpr ck::index_t Hi = 14; +static constexpr ck::index_t Wi = 14; +static constexpr ck::index_t Do = 14; +static constexpr ck::index_t Ho = 14; +static constexpr ck::index_t Wo = 14; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main() +{ + std::array in_lengths{G, N, C, Di, Hi, Wi}; + std::array in_strides{ + C, Di * Hi * Wi * G * C, 1, Hi * Wi * G * C, Wi * G * C, G * C}; + + std::array wei_lengths{G, K, C, Z, Y, X}; + std::array wei_strides{ + K * Z * Y * X * C, Z * Y * X * C, 1, Y * X * C, X * C, C}; + + std::array out_lengths{G, N, K, Do, Ho, Wo}; + std::array out_strides{ + K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K}; + + std::array filter_strides{1, 1, 1}; + std::array filter_dilations{1, 1, 1}; + std::array input_left_pads{1, 1, 1}; + std::array input_right_pads{1, 1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * G * N * Di * Hi * Wi * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Z * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * G * N * Do * Ho * Wo * K); + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD, + InLayout, + OutDataType, + WeiDataType, + ck::Tuple<>, + InDataType, + PassThrough, + PassThrough, + PassThrough>; + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + in.GetDeviceBuffer(), + out_lengths, + out_strides, + wei_lengths, + wei_strides, + {}, + {}, + in_lengths, + in_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = std::size_t(2) * G * N * K * C * Do * Ho * Wo * Y * X; + std::size_t num_bytes = sizeof(InDataType) * G * N * Di * Hi * Wi * C + + sizeof(WeiDataType) * G * K * Z * Y * X * C + + sizeof(OutDataType) * G * N * Do * Ho * Wo * K; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return EXIT_FAILURE; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + in.GetDeviceBuffer(), + out_lengths, + out_strides, + wei_lengths, + wei_strides, + {}, + {}, + in_lengths, + in_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } +} diff --git a/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data_input_fp16_comp_bf8f8.cpp b/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data_input_fp16_comp_bf8f8.cpp new file mode 100644 index 0000000000..2330228d1d --- /dev/null +++ b/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data_input_fp16_comp_bf8f8.cpp @@ -0,0 +1,207 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 2; +static constexpr ck::index_t N = 16; +static constexpr ck::index_t K = 16; +static constexpr ck::index_t C = 16; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 14; +static constexpr ck::index_t Hi = 14; +static constexpr ck::index_t Wi = 14; +static constexpr ck::index_t Do = 14; +static constexpr ck::index_t Ho = 14; +static constexpr ck::index_t Wo = 14; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main() +{ + std::array in_lengths{G, N, C, Di, Hi, Wi}; + std::array in_strides{ + C, Di * Hi * Wi * G * C, 1, Hi * Wi * G * C, Wi * G * C, G * C}; + + std::array wei_lengths{G, K, C, Z, Y, X}; + std::array wei_strides{ + K * Z * Y * X * C, Z * Y * X * C, 1, Y * X * C, X * C, C}; + + std::array out_lengths{G, N, K, Do, Ho, Wo}; + std::array out_strides{ + K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K}; + + std::array filter_strides{1, 1, 1}; + std::array filter_dilations{1, 1, 1}; + std::array input_left_pads{1, 1, 1}; + std::array input_right_pads{1, 1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * G * N * Di * Hi * Wi * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Z * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * G * N * Do * Ho * Wo * K); + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD, + InLayout, + OutDataType, + WeiDataType, + ck::Tuple<>, + InDataType, + PassThrough, + PassThrough, + PassThrough, + ck::bf8_t, + ck::f8_t>; + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + in.GetDeviceBuffer(), + out_lengths, + out_strides, + wei_lengths, + wei_strides, + {}, + {}, + in_lengths, + in_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = std::size_t(2) * G * N * K * C * Do * Ho * Wo * Y * X; + std::size_t num_bytes = sizeof(InDataType) * G * N * Di * Hi * Wi * C + + sizeof(WeiDataType) * G * K * Z * Y * X * C + + sizeof(OutDataType) * G * N * Do * Ho * Wo * K; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return EXIT_FAILURE; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + in.GetDeviceBuffer(), + out_lengths, + out_strides, + wei_lengths, + wei_strides, + {}, + {}, + in_lengths, + in_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } +} diff --git a/client_example/16_convnd_fwd/CMakeLists.txt b/client_example/16_convnd_fwd/CMakeLists.txt index e2580a370c..249c2c030f 100644 --- a/client_example/16_convnd_fwd/CMakeLists.txt +++ b/client_example/16_convnd_fwd/CMakeLists.txt @@ -1,5 +1,15 @@ -add_executable(client_conv3d_fwd_fp16 conv3d_fwd_fp16.cpp) -add_executable(client_conv3d_fwd_fp32 conv3d_fwd_fp32.cpp) +if((DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) + add_executable(client_conv3d_fwd_fp16 conv3d_fwd_fp16.cpp) + target_link_libraries(client_conv3d_fwd_fp16 PRIVATE composable_kernel::device_operations) -target_link_libraries(client_conv3d_fwd_fp16 PRIVATE composable_kernel::device_operations) -target_link_libraries(client_conv3d_fwd_fp32 PRIVATE composable_kernel::device_operations) +endif() + +if((DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES) + add_executable(client_conv3d_fwd_fp16_comp_fp8 conv3d_fwd_fp16_comp_fp8.cpp) + target_link_libraries(client_conv3d_fwd_fp16_comp_fp8 PRIVATE composable_kernel::device_operations) +endif() + +if((DTYPES MATCHES "fp32") OR NOT DEFINED DTYPES) + add_executable(client_conv3d_fwd_fp32 conv3d_fwd_fp32.cpp) + target_link_libraries(client_conv3d_fwd_fp32 PRIVATE composable_kernel::device_operations) +endif() diff --git a/client_example/16_convnd_fwd/common.hpp b/client_example/16_convnd_fwd/common.hpp index 449c9466e8..50c03c49d8 100644 --- a/client_example/16_convnd_fwd/common.hpp +++ b/client_example/16_convnd_fwd/common.hpp @@ -94,7 +94,8 @@ template + ck::index_t NumNonSpatialDim = 3, + typename ComputeType = InDataType> bool run_grouped_conv_fwd(std::array in_lengths, std::array wei_lengths, std::array out_lengths) @@ -184,7 +185,8 @@ bool run_grouped_conv_fwd(std::array; + PassThrough, + ComputeType>; // get device op instances const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< DeviceOp>::GetInstances(); diff --git a/client_example/16_convnd_fwd/conv3d_fwd_fp16_comp_fp8.cpp b/client_example/16_convnd_fwd/conv3d_fwd_fp16_comp_fp8.cpp new file mode 100644 index 0000000000..1651ec2f39 --- /dev/null +++ b/client_example/16_convnd_fwd/conv3d_fwd_fp16_comp_fp8.cpp @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 64; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; + +int main() +{ + return run_grouped_conv_fwd( + {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K}) + ? EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/client_example/21_grouped_gemm_bias/grouped_gemm_fixed_nk_bias_fp16.cpp b/client_example/21_grouped_gemm_bias/grouped_gemm_fixed_nk_bias_fp16.cpp index 3b6dd9a2a9..c758720e10 100644 --- a/client_example/21_grouped_gemm_bias/grouped_gemm_fixed_nk_bias_fp16.cpp +++ b/client_example/21_grouped_gemm_bias/grouped_gemm_fixed_nk_bias_fp16.cpp @@ -60,14 +60,13 @@ int main() int sum_of_m = 0; - Ms = {167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148}; - - int group_count = Ms.size(); + const int group_count = 16; for(int i = 0; i < group_count; ++i) { - Ns.push_back(768); - Ks.push_back(4608); + Ms.push_back(256 + 256 * i); + Ns.push_back(128 + 128 * i); + Ks.push_back(128 + 64 * i); StrideAs.push_back(std::is_same::value ? Ks[i] : Ms[i]); StrideBs.push_back(std::is_same::value ? Ns[i] : Ks[i]); diff --git a/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp16.cpp b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp16.cpp index 3503ae8b24..b16fe90387 100644 --- a/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp16.cpp +++ b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp16.cpp @@ -57,15 +57,13 @@ int main() int sum_of_m = 0; - // Ms = {167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148}; - Ms = {0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0}; - - int group_count = Ms.size(); + const int group_count = 16; for(int i = 0; i < group_count; ++i) { - Ns.push_back(768); - Ks.push_back(4608); + Ms.push_back(256 + 256 * i); + Ns.push_back(128 + 128 * i); + Ks.push_back(128 + 64 * i); StrideAs.push_back(std::is_same::value ? Ks[i] : Ms[i]); StrideBs.push_back(std::is_same::value ? Ns[i] : Ks[i]); diff --git a/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp8.cpp b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp8.cpp index b288550b74..045fe47c4f 100644 --- a/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp8.cpp +++ b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp8.cpp @@ -58,14 +58,13 @@ int main() int sum_of_m = 0; - Ms = {167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148}; - - int group_count = Ms.size(); + const int group_count = 16; for(int i = 0; i < group_count; ++i) { - Ns.push_back(768); - Ks.push_back(4608); + Ms.push_back(256 + 256 * i); + Ns.push_back(128 + 128 * i); + Ks.push_back(128 + 64 * i); StrideAs.push_back(std::is_same::value ? Ks[i] : Ms[i]); StrideBs.push_back(std::is_same::value ? Ns[i] : Ks[i]); diff --git a/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_i8.cpp b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_i8.cpp index c60daa3b36..8f82140f3f 100644 --- a/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_i8.cpp +++ b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_i8.cpp @@ -58,14 +58,13 @@ int main() int sum_of_m = 0; - Ms = {167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148}; - - int group_count = Ms.size(); + const int group_count = 16; for(int i = 0; i < group_count; ++i) { - Ns.push_back(768); - Ks.push_back(4608); + Ms.push_back(256 + 256 * i); + Ns.push_back(128 + 128 * i); + Ks.push_back(128 + 64 * i); StrideAs.push_back(std::is_same::value ? Ks[i] : Ms[i]); StrideBs.push_back(std::is_same::value ? Ns[i] : Ks[i]); diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp index 89d4789c12..95b8526094 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp @@ -296,13 +296,11 @@ int main(int argc, char* argv[]) problem_size.group_count = 16; - problem_size.Ms = { - 167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148}; - for(int i = 0; i < problem_size.group_count; i++) { - problem_size.Ns.push_back(768); - problem_size.Ks.push_back(4608); + problem_size.Ms.push_back(256 + 256 * i); + problem_size.Ns.push_back(128 + 128 * i); + problem_size.Ks.push_back(128 + 64 * i); problem_size.stride_As.push_back(problem_size.Ks[i]); problem_size.stride_Bs.push_back(problem_size.Ks[i]); diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp index 1c50dc051b..84abe1d1db 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp @@ -297,13 +297,11 @@ int main(int argc, char* argv[]) problem_size.group_count = 16; - problem_size.Ms = { - 167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148}; - for(int i = 0; i < problem_size.group_count; i++) { - problem_size.Ns.push_back(768); - problem_size.Ks.push_back(4608); + problem_size.Ms.push_back(256 + 256 * i); + problem_size.Ns.push_back(128 + 128 * i); + problem_size.Ks.push_back(128 + 64 * i); problem_size.stride_As.push_back(problem_size.Ks[i]); problem_size.stride_Bs.push_back(problem_size.Ks[i]); diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp index 743ab96be6..9f8f6cb1e4 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp @@ -66,13 +66,11 @@ int main(int argc, char* argv[]) problem_size.group_count = 16; - problem_size.Ms = { - 167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148}; - for(int i = 0; i < problem_size.group_count; i++) { - problem_size.Ns.push_back(768); - problem_size.Ks.push_back(4608); + problem_size.Ms.push_back(256 + 256 * i); + problem_size.Ns.push_back(128 + 128 * i); + problem_size.Ks.push_back(128 + 64 * i); problem_size.stride_As.push_back(problem_size.Ks[i]); problem_size.stride_Bs.push_back(problem_size.Ks[i]); diff --git a/example/20_grouped_conv_bwd_weight/CMakeLists.txt b/example/20_grouped_conv_bwd_weight/CMakeLists.txt index ab8ed5c2bf..c83b384a01 100644 --- a/example/20_grouped_conv_bwd_weight/CMakeLists.txt +++ b/example/20_grouped_conv_bwd_weight/CMakeLists.txt @@ -11,6 +11,12 @@ foreach(gpu IN LISTS GPU_TARGETS) if(result EQUAL 0) add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16) endif() + if(GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR GPU_TARGETS MATCHES "gfx942") + add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp) + if(result EQUAL 0) + add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8) + endif() + endif() set(target 1) endif() endforeach() diff --git a/example/20_grouped_conv_bwd_weight/common.hpp b/example/20_grouped_conv_bwd_weight/common.hpp index 15727495f0..0e7eb00e4b 100644 --- a/example/20_grouped_conv_bwd_weight/common.hpp +++ b/example/20_grouped_conv_bwd_weight/common.hpp @@ -23,6 +23,12 @@ using BF16 = ck::bhalf_t; using F16 = ck::half_t; using F32 = float; +#ifdef CK_ENABLE_FP8 +using F8 = ck::f8_t; +#endif +#ifdef CK_ENABLE_BF8 +using BF8 = ck::bf8_t; +#endif template using S = ck::Sequence; diff --git a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp index cb6d68e04a..dccde28aa0 100644 --- a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp +++ b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp @@ -65,6 +65,15 @@ using DeviceConvBwdWeightInstance = ck::tensor_operation::device::DeviceGroupedC 5, // CThreadTransferSrcDstVectorDim 4>; // CThreadTransferDstScalarPerVector +template +using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight; + #include "run_grouped_conv_bwd_weight_example.inc" int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); } diff --git a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp index 31e277b5c7..467209ef28 100644 --- a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp +++ b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp @@ -67,6 +67,15 @@ using DeviceConvBwdWeightInstance = S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock 128 / (sizeof(WeiDataType) * CHAR_BIT)>; // CBlockTransferScalarPerVector_NWaveNPerXdl +template +using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight; + #include "run_grouped_conv_bwd_weight_example.inc" int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); } diff --git a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp index 69c831cc54..8b751b0d98 100644 --- a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp +++ b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp @@ -66,6 +66,15 @@ using DeviceConvBwdWeightInstance = S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock 128 / (sizeof(WeiDataType) * CHAR_BIT)>; // CBlockTransferScalarPerVector_NWaveNPerXdl +template +using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight; + #include "run_grouped_conv_bwd_weight_example.inc" int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); } diff --git a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp new file mode 100644 index 0000000000..c678867e74 --- /dev/null +++ b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" + +using InDataType = F16; +using WeiDataType = F16; +using OutDataType = F16; +using AccDataType = F32; +using ComputeTypeA = BF8; +using ComputeTypeB = F8; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = PassThrough; + +template +using DeviceConvBwdWeightInstance = + ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle< + NDimSpatial, + ck::tuple_element_t>, + ck::tuple_element_t>, + ck::tuple_element_t>, + InDataType, // InDataType + WeiDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 2, // NXdlPerWave + S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder + S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 1, // ABlockTransferSrcScalarPerVector + 1, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder + S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 1, // BBlockTransferSrcScalarPerVector + 1, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 2, // CBlockTransferScalarPerVector_NWaveNPerXdl + ComputeTypeA, // ComputeTypeA + ComputeTypeB>; // ComputeTypeB + +template +using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight; + +#include "run_grouped_conv_bwd_weight_example.inc" + +int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); } diff --git a/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc b/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc index 34bca7dc86..755d6addd8 100644 --- a/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc +++ b/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc @@ -1,15 +1,6 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. -template -using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight; - template bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, const ck::utils::conv::ConvParam& conv_param) @@ -46,8 +37,8 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, out.GenerateTensorValue(GeneratorTensor_2{-5, 5}); break; default: - in.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - out.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + in.GenerateTensorValue(GeneratorTensor_3{0.0, 0.2}); + out.GenerateTensorValue(GeneratorTensor_3{-0.1, 0.1}); } DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); @@ -113,18 +104,7 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, return true; } - float avg_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); - - std::size_t flop = conv_param.GetFlops(); - std::size_t num_btype = conv_param.GetByte(); - - float tflops = static_cast(flop) / 1.E9 / avg_time; - - float gb_per_sec = num_btype / 1.E6 / avg_time; - - std::cerr << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" - << std::endl - << "DeviceOp: " << conv.GetTypeString() << std::endl; + invoker.Run(argument, StreamConfig{nullptr, false}); if(config.do_verification) { @@ -148,6 +128,19 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, return ck::utils::check_err(wei_device_result.mData, wei_host_result.mData); } + float avg_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = conv_param.GetFlops(); + std::size_t num_btype = conv_param.GetByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + + float gb_per_sec = num_btype / 1.E6 / avg_time; + + std::cerr << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl + << "DeviceOp: " << conv.GetTypeString() << std::endl; + return true; } diff --git a/example/26_contraction/CMakeLists.txt b/example/26_contraction/CMakeLists.txt index 1a0489ce9c..167a9f147d 100644 --- a/example/26_contraction/CMakeLists.txt +++ b/example/26_contraction/CMakeLists.txt @@ -1,52 +1,4 @@ -add_custom_target(example_contraction) -add_custom_target(example_contraction_scale) -add_custom_target(example_contraction_bilinear) - -# FP32 add_example_executable(example_contraction_bilinear_xdl_fp32 contraction_bilinear_xdl_fp32.cpp) -add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32) - add_example_executable(example_contraction_scale_xdl_fp32 contraction_scale_xdl_fp32.cpp) -add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32) - -add_example_executable(example_contraction_bilinear_xdl_fp32_compute_bf16 contraction_bilinear_xdl_fp32_compute_bf16.cpp) -add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32_compute_bf16) - -add_example_executable(example_contraction_scale_xdl_fp32_compute_bf16 contraction_scale_xdl_fp32_compute_bf16.cpp) -add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32_compute_bf16) - -add_example_executable(example_contraction_bilinear_xdl_fp32_compute_fp16 contraction_bilinear_xdl_fp32_compute_fp16.cpp) -add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32_compute_fp16) - -add_example_executable(example_contraction_scale_xdl_fp32_compute_fp16 contraction_scale_xdl_fp32_compute_fp16.cpp) -add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32_compute_fp16) - -# FP64 add_example_executable(example_contraction_bilinear_xdl_fp64 contraction_bilinear_xdl_fp64.cpp) -add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp64) - add_example_executable(example_contraction_scale_xdl_fp64 contraction_scale_xdl_fp64.cpp) -add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp64) - -add_example_executable(example_contraction_bilinear_xdl_fp64_compute_fp32 contraction_bilinear_xdl_fp64_compute_fp32.cpp) -add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp64_compute_fp32) - -add_example_executable(example_contraction_scale_xdl_fp64_compute_fp32 contraction_scale_xdl_fp64_compute_fp32.cpp) -add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp64_compute_fp32) - -# FP16 -add_example_executable(example_contraction_bilinear_xdl_fp16_compute_fp32 contraction_bilinear_xdl_fp16_compute_fp32.cpp) -add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp16_compute_fp32) - -add_example_executable(example_contraction_scale_xdl_fp16_compute_fp32 contraction_scale_xdl_fp16_compute_fp32.cpp) -add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp16_compute_fp32) - -# BF16 -add_example_executable(example_contraction_bilinear_xdl_bf16_compute_fp32 contraction_bilinear_xdl_bf16_compute_fp32.cpp) -add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_bf16_compute_fp32) - -add_example_executable(example_contraction_scale_xdl_bf16_compute_fp32 contraction_scale_xdl_bf16_compute_fp32.cpp) -add_dependencies(example_contraction_scale example_contraction_scale_xdl_bf16_compute_fp32) - -add_dependencies(example_contraction example_contraction_scale) -add_dependencies(example_contraction example_contraction_bilinear) diff --git a/example/26_contraction/common_instances.hpp b/example/26_contraction/common_instances.hpp deleted file mode 100644 index eb2fa5ce73..0000000000 --- a/example/26_contraction/common_instances.hpp +++ /dev/null @@ -1,183 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" - -using F16 = ck::half_t; -using BF16 = ck::bhalf_t; -using F32 = float; -using F64 = double; - -template -using S = ck::Sequence; - -static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -// Generic instances for fp32, fp16 and bf16 data types. -// clang-format off -template -using DeviceOpInstanceKK_Generic = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| Compute| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Data| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Type| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>; - -template -using DeviceOpInstanceKN_Generic = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| Compute| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Data| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Type| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>; - -template -using DeviceOpInstanceMK_Generic = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| Compute| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Data| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Type| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>; - -template -using DeviceOpInstanceMN_Generic = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| Compute| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Data| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Type| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>; - -// Fp64 instances. -template -using DeviceOpInstanceKK_FP64 = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| Compute| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Data| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Type| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1>; - -template -using DeviceOpInstanceKN_FP64 = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| Compute| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Data| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Type| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>; - -template -using DeviceOpInstanceMK_FP64 = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| Compute| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Data| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Type| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1>; - -template -using DeviceOpInstanceMN_FP64 = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| Compute| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Data| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Type| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 1, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>; - -// clang-format on diff --git a/example/26_contraction/contraction_bilinear_xdl_bf16_compute_fp32.cpp b/example/26_contraction/contraction_bilinear_xdl_bf16_compute_fp32.cpp deleted file mode 100644 index 77ec5669a7..0000000000 --- a/example/26_contraction/contraction_bilinear_xdl_bf16_compute_fp32.cpp +++ /dev/null @@ -1,86 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "common_instances.hpp" - -using ADataType = BF16; -using BDataType = BF16; -using AccDataType = F32; -using CShuffleDataType = BF16; -using DDataType = BF16; -using DsDataType = ck::Tuple; -using EDataType = BF16; -using ComputeDataType = F32; - -static constexpr ck::index_t NumDimM = 2; -static constexpr ck::index_t NumDimN = 2; -static constexpr ck::index_t NumDimK = 2; - -using AElementOp = ck::tensor_operation::element_wise::PassThrough; -using BElementOp = ck::tensor_operation::element_wise::PassThrough; -using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; - -using DeviceOpInstanceKKNN = DeviceOpInstanceKK_Generic; - -using DeviceOpInstanceKNNN = DeviceOpInstanceKN_Generic; - -using DeviceOpInstanceMKNN = DeviceOpInstanceMK_Generic; - -using DeviceOpInstanceMNNN = DeviceOpInstanceMN_Generic; - -using DeviceOpInstance = DeviceOpInstanceKKNN; - -#include "run_contraction_bilinear_example.inc" - -int main(int argc, char* argv[]) { return run_contraction_bilinear_example(argc, argv); } diff --git a/example/26_contraction/contraction_bilinear_xdl_fp16_compute_fp32.cpp b/example/26_contraction/contraction_bilinear_xdl_fp16_compute_fp32.cpp deleted file mode 100644 index 25c4d9eac4..0000000000 --- a/example/26_contraction/contraction_bilinear_xdl_fp16_compute_fp32.cpp +++ /dev/null @@ -1,86 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "common_instances.hpp" - -using ADataType = F16; -using BDataType = F16; -using AccDataType = F32; -using CShuffleDataType = F16; -using DDataType = F16; -using DsDataType = ck::Tuple; -using EDataType = F16; -using ComputeDataType = F32; - -static constexpr ck::index_t NumDimM = 2; -static constexpr ck::index_t NumDimN = 2; -static constexpr ck::index_t NumDimK = 2; - -using AElementOp = ck::tensor_operation::element_wise::PassThrough; -using BElementOp = ck::tensor_operation::element_wise::PassThrough; -using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; - -using DeviceOpInstanceKKNN = DeviceOpInstanceKK_Generic; - -using DeviceOpInstanceKNNN = DeviceOpInstanceKN_Generic; - -using DeviceOpInstanceMKNN = DeviceOpInstanceMK_Generic; - -using DeviceOpInstanceMNNN = DeviceOpInstanceMN_Generic; - -using DeviceOpInstance = DeviceOpInstanceKKNN; - -#include "run_contraction_bilinear_example.inc" - -int main(int argc, char* argv[]) { return run_contraction_bilinear_example(argc, argv); } diff --git a/example/26_contraction/contraction_bilinear_xdl_fp32.cpp b/example/26_contraction/contraction_bilinear_xdl_fp32.cpp index d440e7394e..78522160c8 100644 --- a/example/26_contraction/contraction_bilinear_xdl_fp32.cpp +++ b/example/26_contraction/contraction_bilinear_xdl_fp32.cpp @@ -1,10 +1,29 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +#include +#include +#include +#include + #include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "common_instances.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/numeric.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" + +template +using S = ck::Sequence; + +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ADataType = F32; using BDataType = F32; @@ -13,7 +32,6 @@ using CShuffleDataType = F32; using DDataType = F32; using DsDataType = ck::Tuple; using EDataType = F32; -using ComputeDataType = F32; static constexpr ck::index_t NumDimM = 2; static constexpr ck::index_t NumDimN = 2; @@ -23,64 +41,253 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; -using DeviceOpInstanceKKNN = DeviceOpInstanceKK_Generic; +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; -using DeviceOpInstanceKNNN = DeviceOpInstanceKN_Generic; +// clang-format off +using DeviceOpInstanceKKNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>; -using DeviceOpInstanceMKNN = DeviceOpInstanceMK_Generic; +using DeviceOpInstanceKNNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>; -using DeviceOpInstanceMNNN = DeviceOpInstanceMN_Generic; +using DeviceOpInstanceMKNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>; + +using DeviceOpInstanceMNNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>; +// clang-format on using DeviceOpInstance = DeviceOpInstanceKKNN; -#include "run_contraction_bilinear_example.inc" +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; -int main(int argc, char* argv[]) { return run_contraction_bilinear_example(argc, argv); } + // A[M0, M1, K0, K1] + std::vector a_ms_ks_lengths{30, 128, 32, 64}; + std::vector a_ms_ks_strides{524288, 4096, 128, 1}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{524288, 4096, 128, 1}; + // D[M0, M1, N0, N1] + std::vector d_ms_ns_lengths{30, 128, 32, 64}; + std::vector d_ms_ns_strides{524288, 4096, 128, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{524288, 4096, 128, 1}; + + float alpha = 1.f; + float beta = 1.f; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 28) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + const ck::index_t M0 = std::stoi(argv[4]); + const ck::index_t M1 = std::stoi(argv[5]); + + const ck::index_t N0 = std::stoi(argv[6]); + const ck::index_t N1 = std::stoi(argv[7]); + + const ck::index_t K0 = std::stoi(argv[8]); + const ck::index_t K1 = std::stoi(argv[9]); + + a_ms_ks_lengths = {M0, M1, K0, K1}; + a_ms_ks_strides = { + std::stoi(argv[10]), std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13])}; + + b_ns_ks_lengths = {N0, N1, K0, K1}; + b_ns_ks_strides = { + std::stoi(argv[14]), std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17])}; + + d_ms_ns_lengths = {M0, M1, N0, N1}; + d_ms_ns_strides = { + std::stoi(argv[18]), std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21])}; + + e_ms_ns_lengths = {M0, M1, N0, N1}; + e_ms_ns_strides = { + std::stoi(argv[22]), std::stoi(argv[23]), std::stoi(argv[24]), std::stoi(argv[25])}; + + alpha = std::stof(argv[26]); + beta = std::stof(argv[27]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 7: M0, M1, N0, N1, K0, K1\n"); + printf("arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n"); + printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n"); + printf("arg18 to 21: Stride_D_M0, Stride_D_M1, Stride_D_N0, Stride_D_N1\n"); + printf("arg22 to 25: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n"); + printf("arg26 to 27: alpha, beta\n"); + exit(0); + } + + Tensor a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides); + Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); + Tensor d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides); + Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); + + std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl; + std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; + std::cout << "d_ms_ns: " << d_ms_ns.mDesc << std::endl; + std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_ms_ns.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_ms_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_ms_ns.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_ms_ks.mData.data()); + b_device_buf.ToDevice(b_ns_ks.mData.data()); + d_device_buf.ToDevice(d_ms_ns.mData.data()); + + // set zero + e_device_buf.SetZero(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{alpha, beta}; + + // device operation + auto op = DeviceOpInstance{}; + auto invoker = op.MakeInvoker(); + auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + std::array, 1>{d_ms_ns_lengths}, + std::array, 1>{d_ms_ns_strides}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + if(!op.IsSupportedArgument(argument)) + { + std::cout << op.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + ck::index_t M = + ck::accumulate_n(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); + + ck::index_t N = ck::accumulate_n( + e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); + + ck::index_t K = ck::accumulate_n( + a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(DDataType) * M * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << op.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_ms_ns_device_result.mData.data()); + + if(do_verification) + { + Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + + using ReferenceOpInstance = + ck::tensor_operation::host::ReferenceContraction_M2_N2_K2; + + auto ref_op = ReferenceOpInstance{}; + auto ref_invoker = ref_op.MakeInvoker(); + + auto ref_argument = + ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op); + + ref_invoker.Run(ref_argument); + + for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0) + { + for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1) + { + for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0) + { + for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1) + { + cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1), + c_ms_ns_host_result(m0, m1, n0, n1), + d_ms_ns(m0, m1, n0, n1)); + } + } + } + } + + return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/26_contraction/contraction_bilinear_xdl_fp32_compute_bf16.cpp b/example/26_contraction/contraction_bilinear_xdl_fp32_compute_bf16.cpp deleted file mode 100644 index 5b748949bd..0000000000 --- a/example/26_contraction/contraction_bilinear_xdl_fp32_compute_bf16.cpp +++ /dev/null @@ -1,86 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "common_instances.hpp" - -using ADataType = F32; -using BDataType = F32; -using AccDataType = F32; -using CShuffleDataType = F32; -using DDataType = F32; -using DsDataType = ck::Tuple; -using EDataType = F32; -using ComputeDataType = BF16; - -static constexpr ck::index_t NumDimM = 2; -static constexpr ck::index_t NumDimN = 2; -static constexpr ck::index_t NumDimK = 2; - -using AElementOp = ck::tensor_operation::element_wise::PassThrough; -using BElementOp = ck::tensor_operation::element_wise::PassThrough; -using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; - -using DeviceOpInstanceKKNN = DeviceOpInstanceKK_Generic; - -using DeviceOpInstanceKNNN = DeviceOpInstanceKN_Generic; - -using DeviceOpInstanceMKNN = DeviceOpInstanceMK_Generic; - -using DeviceOpInstanceMNNN = DeviceOpInstanceMN_Generic; - -using DeviceOpInstance = DeviceOpInstanceKKNN; - -#include "run_contraction_bilinear_example.inc" - -int main(int argc, char* argv[]) { return run_contraction_bilinear_example(argc, argv); } diff --git a/example/26_contraction/contraction_bilinear_xdl_fp32_compute_fp16.cpp b/example/26_contraction/contraction_bilinear_xdl_fp32_compute_fp16.cpp deleted file mode 100644 index 334cd615d1..0000000000 --- a/example/26_contraction/contraction_bilinear_xdl_fp32_compute_fp16.cpp +++ /dev/null @@ -1,86 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "common_instances.hpp" - -using ADataType = F32; -using BDataType = F32; -using AccDataType = F32; -using CShuffleDataType = F32; -using DDataType = F32; -using DsDataType = ck::Tuple; -using EDataType = F32; -using ComputeDataType = F16; - -static constexpr ck::index_t NumDimM = 2; -static constexpr ck::index_t NumDimN = 2; -static constexpr ck::index_t NumDimK = 2; - -using AElementOp = ck::tensor_operation::element_wise::PassThrough; -using BElementOp = ck::tensor_operation::element_wise::PassThrough; -using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; - -using DeviceOpInstanceKKNN = DeviceOpInstanceKK_Generic; - -using DeviceOpInstanceKNNN = DeviceOpInstanceKN_Generic; - -using DeviceOpInstanceMKNN = DeviceOpInstanceMK_Generic; - -using DeviceOpInstanceMNNN = DeviceOpInstanceMN_Generic; - -using DeviceOpInstance = DeviceOpInstanceKKNN; - -#include "run_contraction_bilinear_example.inc" - -int main(int argc, char* argv[]) { return run_contraction_bilinear_example(argc, argv); } diff --git a/example/26_contraction/contraction_bilinear_xdl_fp64.cpp b/example/26_contraction/contraction_bilinear_xdl_fp64.cpp index 86ea7e0212..6cceed5bc1 100644 --- a/example/26_contraction/contraction_bilinear_xdl_fp64.cpp +++ b/example/26_contraction/contraction_bilinear_xdl_fp64.cpp @@ -1,10 +1,29 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +#include +#include +#include +#include + #include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "common_instances.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/numeric.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" + +template +using S = ck::Sequence; + +using F64 = double; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ADataType = F64; using BDataType = F64; @@ -13,7 +32,6 @@ using CShuffleDataType = F64; using DDataType = F64; using DsDataType = ck::Tuple; using EDataType = F64; -using ComputeDataType = F64; static constexpr ck::index_t NumDimM = 2; static constexpr ck::index_t NumDimN = 2; @@ -23,64 +41,253 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; -using DeviceOpInstanceKKNN = DeviceOpInstanceKK_FP64; +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; -using DeviceOpInstanceKNNN = DeviceOpInstanceKN_FP64; +// clang-format off +using DeviceOpInstanceKKNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1>; -using DeviceOpInstanceMKNN = DeviceOpInstanceMK_FP64; +using DeviceOpInstanceKNNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>; -using DeviceOpInstanceMNNN = DeviceOpInstanceMN_FP64; +using DeviceOpInstanceMKNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1>; + +using DeviceOpInstanceMNNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 1, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>; +// clang-format on using DeviceOpInstance = DeviceOpInstanceKKNN; -#include "run_contraction_bilinear_example.inc" +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; -int main(int argc, char* argv[]) { return run_contraction_bilinear_example(argc, argv); } + // A[M0, M1, K0, K1] + std::vector a_ms_ks_lengths{30, 128, 32, 64}; + std::vector a_ms_ks_strides{524288, 4096, 128, 1}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{524288, 4096, 128, 1}; + // D[M0, M1, N0, N1] + std::vector d_ms_ns_lengths{30, 128, 32, 64}; + std::vector d_ms_ns_strides{524288, 4096, 128, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{524288, 4096, 128, 1}; + + float alpha = 1.f; + float beta = 1.f; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 28) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + const ck::index_t M0 = std::stoi(argv[4]); + const ck::index_t M1 = std::stoi(argv[5]); + + const ck::index_t N0 = std::stoi(argv[6]); + const ck::index_t N1 = std::stoi(argv[7]); + + const ck::index_t K0 = std::stoi(argv[8]); + const ck::index_t K1 = std::stoi(argv[9]); + + a_ms_ks_lengths = {M0, M1, K0, K1}; + a_ms_ks_strides = { + std::stoi(argv[10]), std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13])}; + + b_ns_ks_lengths = {N0, N1, K0, K1}; + b_ns_ks_strides = { + std::stoi(argv[14]), std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17])}; + + d_ms_ns_lengths = {M0, M1, N0, N1}; + d_ms_ns_strides = { + std::stoi(argv[18]), std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21])}; + + e_ms_ns_lengths = {M0, M1, N0, N1}; + e_ms_ns_strides = { + std::stoi(argv[22]), std::stoi(argv[23]), std::stoi(argv[24]), std::stoi(argv[25])}; + + alpha = std::stof(argv[26]); + beta = std::stof(argv[27]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 7: M0, M1, N0, N1, K0, K1\n"); + printf("arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n"); + printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n"); + printf("arg18 to 21: Stride_D_M0, Stride_D_M1, Stride_D_N0, Stride_D_N1\n"); + printf("arg22 to 25: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n"); + printf("arg26 to 27: alpha, beta\n"); + exit(0); + } + + Tensor a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides); + Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); + Tensor d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides); + Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); + + std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl; + std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; + std::cout << "d_ms_ns: " << d_ms_ns.mDesc << std::endl; + std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_ms_ns.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_ms_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_ms_ns.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_ms_ks.mData.data()); + b_device_buf.ToDevice(b_ns_ks.mData.data()); + d_device_buf.ToDevice(d_ms_ns.mData.data()); + + // set zero + e_device_buf.SetZero(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{alpha, beta}; + + // device operation + auto op = DeviceOpInstance{}; + auto invoker = op.MakeInvoker(); + auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + std::array, 1>{d_ms_ns_lengths}, + std::array, 1>{d_ms_ns_strides}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + if(!op.IsSupportedArgument(argument)) + { + std::cout << op.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + ck::index_t M = + ck::accumulate_n(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); + + ck::index_t N = ck::accumulate_n( + e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); + + ck::index_t K = ck::accumulate_n( + a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(DDataType) * M * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << op.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_ms_ns_device_result.mData.data()); + + if(do_verification) + { + Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + + using ReferenceOpInstance = + ck::tensor_operation::host::ReferenceContraction_M2_N2_K2; + + auto ref_op = ReferenceOpInstance{}; + auto ref_invoker = ref_op.MakeInvoker(); + + auto ref_argument = + ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op); + + ref_invoker.Run(ref_argument); + + for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0) + { + for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1) + { + for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0) + { + for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1) + { + cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1), + c_ms_ns_host_result(m0, m1, n0, n1), + d_ms_ns(m0, m1, n0, n1)); + } + } + } + } + + return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/26_contraction/contraction_bilinear_xdl_fp64_compute_fp32.cpp b/example/26_contraction/contraction_bilinear_xdl_fp64_compute_fp32.cpp deleted file mode 100644 index b0a2a9a10f..0000000000 --- a/example/26_contraction/contraction_bilinear_xdl_fp64_compute_fp32.cpp +++ /dev/null @@ -1,86 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "common_instances.hpp" - -using ADataType = F64; -using BDataType = F64; -using AccDataType = F32; -using CShuffleDataType = F64; -using DDataType = F64; -using DsDataType = ck::Tuple; -using EDataType = F64; -using ComputeDataType = F32; - -static constexpr ck::index_t NumDimM = 2; -static constexpr ck::index_t NumDimN = 2; -static constexpr ck::index_t NumDimK = 2; - -using AElementOp = ck::tensor_operation::element_wise::PassThrough; -using BElementOp = ck::tensor_operation::element_wise::PassThrough; -using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; - -using DeviceOpInstanceKKNN = DeviceOpInstanceKK_FP64; - -using DeviceOpInstanceKNNN = DeviceOpInstanceKN_FP64; - -using DeviceOpInstanceMKNN = DeviceOpInstanceMK_FP64; - -using DeviceOpInstanceMNNN = DeviceOpInstanceMN_FP64; - -using DeviceOpInstance = DeviceOpInstanceKKNN; - -#include "run_contraction_bilinear_example.inc" - -int main(int argc, char* argv[]) { return run_contraction_bilinear_example(argc, argv); } diff --git a/example/26_contraction/contraction_scale_xdl_bf16_compute_fp32.cpp b/example/26_contraction/contraction_scale_xdl_bf16_compute_fp32.cpp deleted file mode 100644 index 32652fc6f3..0000000000 --- a/example/26_contraction/contraction_scale_xdl_bf16_compute_fp32.cpp +++ /dev/null @@ -1,85 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "common_instances.hpp" - -using ADataType = BF16; -using BDataType = BF16; -using AccDataType = F32; -using CShuffleDataType = BF16; -using DsDataType = ck::Tuple<>; -using EDataType = BF16; -using ComputeDataType = F32; - -static constexpr ck::index_t NumDimM = 2; -static constexpr ck::index_t NumDimN = 2; -static constexpr ck::index_t NumDimK = 2; - -using AElementOp = ck::tensor_operation::element_wise::PassThrough; -using BElementOp = ck::tensor_operation::element_wise::PassThrough; -using CDEElementOp = ck::tensor_operation::element_wise::Scale; - -using DeviceOpInstanceKKN = DeviceOpInstanceKK_Generic; - -using DeviceOpInstanceKNN = DeviceOpInstanceKN_Generic; - -using DeviceOpInstanceMKN = DeviceOpInstanceMK_Generic; - -using DeviceOpInstanceMNN = DeviceOpInstanceMN_Generic; - -using DeviceOpInstance = DeviceOpInstanceKKN; - -#include "run_contraction_scale_example.inc" - -int main(int argc, char* argv[]) { return run_contraction_scale_example(argc, argv); } diff --git a/example/26_contraction/contraction_scale_xdl_fp16_compute_fp32.cpp b/example/26_contraction/contraction_scale_xdl_fp16_compute_fp32.cpp deleted file mode 100644 index a4797e2400..0000000000 --- a/example/26_contraction/contraction_scale_xdl_fp16_compute_fp32.cpp +++ /dev/null @@ -1,85 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "common_instances.hpp" - -using ADataType = F16; -using BDataType = F16; -using AccDataType = F32; -using CShuffleDataType = F16; -using DsDataType = ck::Tuple<>; -using EDataType = F16; -using ComputeDataType = F32; - -static constexpr ck::index_t NumDimM = 2; -static constexpr ck::index_t NumDimN = 2; -static constexpr ck::index_t NumDimK = 2; - -using AElementOp = ck::tensor_operation::element_wise::PassThrough; -using BElementOp = ck::tensor_operation::element_wise::PassThrough; -using CDEElementOp = ck::tensor_operation::element_wise::Scale; - -using DeviceOpInstanceKKN = DeviceOpInstanceKK_Generic; - -using DeviceOpInstanceKNN = DeviceOpInstanceKN_Generic; - -using DeviceOpInstanceMKN = DeviceOpInstanceMK_Generic; - -using DeviceOpInstanceMNN = DeviceOpInstanceMN_Generic; - -using DeviceOpInstance = DeviceOpInstanceKKN; - -#include "run_contraction_scale_example.inc" - -int main(int argc, char* argv[]) { return run_contraction_scale_example(argc, argv); } diff --git a/example/26_contraction/contraction_scale_xdl_fp32.cpp b/example/26_contraction/contraction_scale_xdl_fp32.cpp index fe984b29d6..1574f5d18f 100644 --- a/example/26_contraction/contraction_scale_xdl_fp32.cpp +++ b/example/26_contraction/contraction_scale_xdl_fp32.cpp @@ -1,10 +1,29 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +#include +#include +#include +#include + #include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "common_instances.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/numeric.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" + +template +using S = ck::Sequence; + +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ADataType = F32; using BDataType = F32; @@ -12,7 +31,6 @@ using AccDataType = F32; using CShuffleDataType = F32; using DsDataType = ck::Tuple<>; using EDataType = F32; -using ComputeDataType = F32; static constexpr ck::index_t NumDimM = 2; static constexpr ck::index_t NumDimN = 2; @@ -22,64 +40,237 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CDEElementOp = ck::tensor_operation::element_wise::Scale; -using DeviceOpInstanceKKN = DeviceOpInstanceKK_Generic; +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; -using DeviceOpInstanceKNN = DeviceOpInstanceKN_Generic; +// clang-format off +using DeviceOpInstanceKKN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>; -using DeviceOpInstanceMKN = DeviceOpInstanceMK_Generic; +using DeviceOpInstanceKNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>; -using DeviceOpInstanceMNN = DeviceOpInstanceMN_Generic; +using DeviceOpInstanceMKN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>; + +using DeviceOpInstanceMNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>; +// clang-format on using DeviceOpInstance = DeviceOpInstanceKKN; -#include "run_contraction_scale_example.inc" +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; -int main(int argc, char* argv[]) { return run_contraction_scale_example(argc, argv); } + // A[M0, M1, K0, K1] + std::vector a_ms_ks_lengths{30, 128, 32, 64}; + std::vector a_ms_ks_strides{524288, 4096, 128, 1}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{524288, 4096, 128, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{524288, 4096, 128, 1}; + + float scale = 1.f; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 23) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + const ck::index_t M0 = std::stoi(argv[4]); + const ck::index_t M1 = std::stoi(argv[5]); + + const ck::index_t N0 = std::stoi(argv[6]); + const ck::index_t N1 = std::stoi(argv[7]); + + const ck::index_t K0 = std::stoi(argv[8]); + const ck::index_t K1 = std::stoi(argv[9]); + + a_ms_ks_lengths = {M0, M1, K0, K1}; + a_ms_ks_strides = { + std::stoi(argv[10]), std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13])}; + + b_ns_ks_lengths = {N0, N1, K0, K1}; + b_ns_ks_strides = { + std::stoi(argv[14]), std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17])}; + + e_ms_ns_lengths = {M0, M1, N0, N1}; + e_ms_ns_strides = { + std::stoi(argv[18]), std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21])}; + + scale = std::stof(argv[22]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M0, M1, N0, N1, K0, K1\n"); + printf("arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n"); + printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n"); + printf("arg18 to 21: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n"); + printf("arg22: scale\n"); + exit(0); + } + + Tensor a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides); + Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); + Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); + + std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl; + std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; + std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_ms_ks.mData.data()); + b_device_buf.ToDevice(b_ns_ks.mData.data()); + + // set zero + e_device_buf.SetZero(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{scale}; + + // device operation + auto op = DeviceOpInstance{}; + auto invoker = op.MakeInvoker(); + auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{}, + e_device_buf.GetDeviceBuffer(), + a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + std::array, 0>{}, + std::array, 0>{}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + if(!op.IsSupportedArgument(argument)) + { + std::cout << op.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + ck::index_t M = + ck::accumulate_n(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); + + ck::index_t N = ck::accumulate_n( + e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); + + ck::index_t K = ck::accumulate_n( + a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + +sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << op.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_ms_ns_device_result.mData.data()); + + if(do_verification) + { + Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + + using ReferenceOpInstance = + ck::tensor_operation::host::ReferenceContraction_M2_N2_K2; + + auto ref_op = ReferenceOpInstance{}; + auto ref_invoker = ref_op.MakeInvoker(); + + Tensor empty_tensor(std::vector{}, std::vector{}); + auto ref_argument = + ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op); + + ref_invoker.Run(ref_argument); + + for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0) + { + for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1) + { + for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0) + { + for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1) + { + cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1), + c_ms_ns_host_result(m0, m1, n0, n1)); + } + } + } + } + + return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/26_contraction/contraction_scale_xdl_fp32_compute_bf16.cpp b/example/26_contraction/contraction_scale_xdl_fp32_compute_bf16.cpp deleted file mode 100644 index 83c8fec179..0000000000 --- a/example/26_contraction/contraction_scale_xdl_fp32_compute_bf16.cpp +++ /dev/null @@ -1,85 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "common_instances.hpp" - -using ADataType = F32; -using BDataType = F32; -using AccDataType = F32; -using CShuffleDataType = F32; -using DsDataType = ck::Tuple<>; -using EDataType = F32; -using ComputeDataType = BF16; - -static constexpr ck::index_t NumDimM = 2; -static constexpr ck::index_t NumDimN = 2; -static constexpr ck::index_t NumDimK = 2; - -using AElementOp = ck::tensor_operation::element_wise::PassThrough; -using BElementOp = ck::tensor_operation::element_wise::PassThrough; -using CDEElementOp = ck::tensor_operation::element_wise::Scale; - -using DeviceOpInstanceKKN = DeviceOpInstanceKK_Generic; - -using DeviceOpInstanceKNN = DeviceOpInstanceKN_Generic; - -using DeviceOpInstanceMKN = DeviceOpInstanceMK_Generic; - -using DeviceOpInstanceMNN = DeviceOpInstanceMN_Generic; - -using DeviceOpInstance = DeviceOpInstanceKKN; - -#include "run_contraction_scale_example.inc" - -int main(int argc, char* argv[]) { return run_contraction_scale_example(argc, argv); } diff --git a/example/26_contraction/contraction_scale_xdl_fp32_compute_fp16.cpp b/example/26_contraction/contraction_scale_xdl_fp32_compute_fp16.cpp deleted file mode 100644 index 27e42e8203..0000000000 --- a/example/26_contraction/contraction_scale_xdl_fp32_compute_fp16.cpp +++ /dev/null @@ -1,85 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "common_instances.hpp" - -using ADataType = F32; -using BDataType = F32; -using AccDataType = F32; -using CShuffleDataType = F32; -using DsDataType = ck::Tuple<>; -using EDataType = F32; -using ComputeDataType = F16; - -static constexpr ck::index_t NumDimM = 2; -static constexpr ck::index_t NumDimN = 2; -static constexpr ck::index_t NumDimK = 2; - -using AElementOp = ck::tensor_operation::element_wise::PassThrough; -using BElementOp = ck::tensor_operation::element_wise::PassThrough; -using CDEElementOp = ck::tensor_operation::element_wise::Scale; - -using DeviceOpInstanceKKN = DeviceOpInstanceKK_Generic; - -using DeviceOpInstanceKNN = DeviceOpInstanceKN_Generic; - -using DeviceOpInstanceMKN = DeviceOpInstanceMK_Generic; - -using DeviceOpInstanceMNN = DeviceOpInstanceMN_Generic; - -using DeviceOpInstance = DeviceOpInstanceKKN; - -#include "run_contraction_scale_example.inc" - -int main(int argc, char* argv[]) { return run_contraction_scale_example(argc, argv); } diff --git a/example/26_contraction/contraction_scale_xdl_fp64.cpp b/example/26_contraction/contraction_scale_xdl_fp64.cpp index 37374780b6..3dacc70887 100644 --- a/example/26_contraction/contraction_scale_xdl_fp64.cpp +++ b/example/26_contraction/contraction_scale_xdl_fp64.cpp @@ -1,10 +1,29 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +#include +#include +#include +#include + #include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "common_instances.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/numeric.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" + +template +using S = ck::Sequence; + +using F64 = double; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ADataType = F64; using BDataType = F64; @@ -12,7 +31,6 @@ using AccDataType = F64; using CShuffleDataType = F64; using DsDataType = ck::Tuple<>; using EDataType = F64; -using ComputeDataType = F64; static constexpr ck::index_t NumDimM = 2; static constexpr ck::index_t NumDimN = 2; @@ -22,64 +40,237 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CDEElementOp = ck::tensor_operation::element_wise::Scale; -using DeviceOpInstanceKKN = DeviceOpInstanceKK_FP64; +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; -using DeviceOpInstanceKNN = DeviceOpInstanceKN_FP64; +// clang-format off +using DeviceOpInstanceKKN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1>; -using DeviceOpInstanceMKN = DeviceOpInstanceMK_FP64; +using DeviceOpInstanceKNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>; -using DeviceOpInstanceMNN = DeviceOpInstanceMN_FP64; +using DeviceOpInstanceMKN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1>; + +using DeviceOpInstanceMNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 1, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>; +// clang-format on using DeviceOpInstance = DeviceOpInstanceKKN; -#include "run_contraction_scale_example.inc" +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; -int main(int argc, char* argv[]) { return run_contraction_scale_example(argc, argv); } + // A[M0, M1, K0, K1] + std::vector a_ms_ks_lengths{30, 128, 32, 64}; + std::vector a_ms_ks_strides{524288, 4096, 128, 1}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{524288, 4096, 128, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{524288, 4096, 128, 1}; + + float scale = 1.f; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 23) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + const ck::index_t M0 = std::stoi(argv[4]); + const ck::index_t M1 = std::stoi(argv[5]); + + const ck::index_t N0 = std::stoi(argv[6]); + const ck::index_t N1 = std::stoi(argv[7]); + + const ck::index_t K0 = std::stoi(argv[8]); + const ck::index_t K1 = std::stoi(argv[9]); + + a_ms_ks_lengths = {M0, M1, K0, K1}; + a_ms_ks_strides = { + std::stoi(argv[10]), std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13])}; + + b_ns_ks_lengths = {N0, N1, K0, K1}; + b_ns_ks_strides = { + std::stoi(argv[14]), std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17])}; + + e_ms_ns_lengths = {M0, M1, N0, N1}; + e_ms_ns_strides = { + std::stoi(argv[18]), std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21])}; + + scale = std::stof(argv[22]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M0, M1, N0, N1, K0, K1\n"); + printf("arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n"); + printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n"); + printf("arg18 to 21: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n"); + printf("arg22: scale\n"); + exit(0); + } + + Tensor a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides); + Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); + Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); + + std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl; + std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; + std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_ms_ks.mData.data()); + b_device_buf.ToDevice(b_ns_ks.mData.data()); + + // set zero + e_device_buf.SetZero(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{scale}; + + // device operation + auto op = DeviceOpInstance{}; + auto invoker = op.MakeInvoker(); + auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{}, + e_device_buf.GetDeviceBuffer(), + a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + std::array, 0>{}, + std::array, 0>{}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + if(!op.IsSupportedArgument(argument)) + { + std::cout << op.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + ck::index_t M = + ck::accumulate_n(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); + + ck::index_t N = ck::accumulate_n( + e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); + + ck::index_t K = ck::accumulate_n( + a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + +sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << op.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_ms_ns_device_result.mData.data()); + + if(do_verification) + { + Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + + using ReferenceOpInstance = + ck::tensor_operation::host::ReferenceContraction_M2_N2_K2; + + auto ref_op = ReferenceOpInstance{}; + auto ref_invoker = ref_op.MakeInvoker(); + + Tensor empty_tensor(std::vector{}, std::vector{}); + auto ref_argument = + ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op); + + ref_invoker.Run(ref_argument); + + for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0) + { + for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1) + { + for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0) + { + for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1) + { + cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1), + c_ms_ns_host_result(m0, m1, n0, n1)); + } + } + } + } + + return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/26_contraction/contraction_scale_xdl_fp64_compute_fp32.cpp b/example/26_contraction/contraction_scale_xdl_fp64_compute_fp32.cpp deleted file mode 100644 index 5dbeb23154..0000000000 --- a/example/26_contraction/contraction_scale_xdl_fp64_compute_fp32.cpp +++ /dev/null @@ -1,85 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "common_instances.hpp" - -using ADataType = F64; -using BDataType = F64; -using AccDataType = F32; -using CShuffleDataType = F64; -using DsDataType = ck::Tuple<>; -using EDataType = F64; -using ComputeDataType = F32; - -static constexpr ck::index_t NumDimM = 2; -static constexpr ck::index_t NumDimN = 2; -static constexpr ck::index_t NumDimK = 2; - -using AElementOp = ck::tensor_operation::element_wise::PassThrough; -using BElementOp = ck::tensor_operation::element_wise::PassThrough; -using CDEElementOp = ck::tensor_operation::element_wise::Scale; - -using DeviceOpInstanceKKN = DeviceOpInstanceKK_FP64; - -using DeviceOpInstanceKNN = DeviceOpInstanceKN_FP64; - -using DeviceOpInstanceMKN = DeviceOpInstanceMK_FP64; - -using DeviceOpInstanceMNN = DeviceOpInstanceMN_FP64; - -using DeviceOpInstance = DeviceOpInstanceKKN; - -#include "run_contraction_scale_example.inc" - -int main(int argc, char* argv[]) { return run_contraction_scale_example(argc, argv); } diff --git a/example/26_contraction/run_contraction_bilinear_example.inc b/example/26_contraction/run_contraction_bilinear_example.inc deleted file mode 100644 index 7b9a01e0ae..0000000000 --- a/example/26_contraction/run_contraction_bilinear_example.inc +++ /dev/null @@ -1,234 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/numeric.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" - -int run_contraction_bilinear_example(int argc, char* argv[]) -{ - bool do_verification = true; - int init_method = 1; - bool time_kernel = false; - - // A[M0, M1, K0, K1] - std::vector a_ms_ks_lengths{30, 128, 32, 64}; - std::vector a_ms_ks_strides{524288, 4096, 128, 1}; - // B[N0, N1, K0, K1] - std::vector b_ns_ks_lengths{32, 64, 32, 64}; - std::vector b_ns_ks_strides{524288, 4096, 128, 1}; - // D[M0, M1, N0, N1] - std::vector d_ms_ns_lengths{30, 128, 32, 64}; - std::vector d_ms_ns_strides{524288, 4096, 128, 1}; - // E[M0, M1, N0, N1] - std::vector e_ms_ns_lengths{30, 128, 32, 64}; - std::vector e_ms_ns_strides{524288, 4096, 128, 1}; - - float alpha = 1.f; - float beta = 1.f; - - if(argc == 1) - { - // use default case - } - else if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else if(argc == 28) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - - const ck::index_t M0 = std::stoi(argv[4]); - const ck::index_t M1 = std::stoi(argv[5]); - - const ck::index_t N0 = std::stoi(argv[6]); - const ck::index_t N1 = std::stoi(argv[7]); - - const ck::index_t K0 = std::stoi(argv[8]); - const ck::index_t K1 = std::stoi(argv[9]); - - a_ms_ks_lengths = {M0, M1, K0, K1}; - a_ms_ks_strides = { - std::stoi(argv[10]), std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13])}; - - b_ns_ks_lengths = {N0, N1, K0, K1}; - b_ns_ks_strides = { - std::stoi(argv[14]), std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17])}; - - d_ms_ns_lengths = {M0, M1, N0, N1}; - d_ms_ns_strides = { - std::stoi(argv[18]), std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21])}; - - e_ms_ns_lengths = {M0, M1, N0, N1}; - e_ms_ns_strides = { - std::stoi(argv[22]), std::stoi(argv[23]), std::stoi(argv[24]), std::stoi(argv[25])}; - - alpha = std::stof(argv[26]); - beta = std::stof(argv[27]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=no, 1=yes)\n"); - printf("arg4 to 9: M0, M1, N0, N1, K0, K1\n"); - printf("arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n"); - printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n"); - printf("arg18 to 21: Stride_D_M0, Stride_D_M1, Stride_D_N0, Stride_D_N1\n"); - printf("arg22 to 25: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n"); - printf("arg26 to 27: alpha, beta\n"); - exit(0); - } - - Tensor a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides); - Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); - Tensor d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides); - Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); - - std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl; - std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; - std::cout << "d_ms_ns: " << d_ms_ns.mDesc << std::endl; - std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - a_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - d_ms_ns.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - a_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - d_ms_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - break; - } - - DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize()); - DeviceMem d_device_buf(sizeof(DDataType) * d_ms_ns.mDesc.GetElementSpaceSize()); - DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a_ms_ks.mData.data()); - b_device_buf.ToDevice(b_ns_ks.mData.data()); - d_device_buf.ToDevice(d_ms_ns.mData.data()); - - // set zero - e_device_buf.SetZero(); - - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto cde_element_op = CDEElementOp{alpha, beta}; - - // device operation - auto op = DeviceOpInstance{}; - auto invoker = op.MakeInvoker(); - auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), - b_device_buf.GetDeviceBuffer(), - std::array{d_device_buf.GetDeviceBuffer()}, - e_device_buf.GetDeviceBuffer(), - a_ms_ks_lengths, - a_ms_ks_strides, - b_ns_ks_lengths, - b_ns_ks_strides, - std::array, 1>{d_ms_ns_lengths}, - std::array, 1>{d_ms_ns_strides}, - e_ms_ns_lengths, - e_ms_ns_strides, - a_element_op, - b_element_op, - cde_element_op); - - if(!op.IsSupportedArgument(argument)) - { - std::cout << op.GetTypeString() << " does not support this problem" << std::endl; - - return 0; - } - - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - - ck::index_t M = - ck::accumulate_n(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); - - ck::index_t N = ck::accumulate_n( - e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); - - ck::index_t K = ck::accumulate_n( - a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); - - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + - sizeof(DDataType) * M * N + sizeof(EDataType) * M * N; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << op.GetTypeString() << std::endl; - - e_device_buf.FromDevice(e_ms_ns_device_result.mData.data()); - - if(do_verification) - { - Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - - using ReferenceOpInstance = - ck::tensor_operation::host::ReferenceContraction_M2_N2_K2; - - auto ref_op = ReferenceOpInstance{}; - auto ref_invoker = ref_op.MakeInvoker(); - - auto ref_argument = - ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op); - - ref_invoker.Run(ref_argument); - - for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0) - { - for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1) - { - for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0) - { - for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1) - { - cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1), - c_ms_ns_host_result(m0, m1, n0, n1), - d_ms_ns(m0, m1, n0, n1)); - } - } - } - } - - return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; - } - - return 0; -} diff --git a/example/26_contraction/run_contraction_scale_example.inc b/example/26_contraction/run_contraction_scale_example.inc deleted file mode 100644 index 65ca182da2..0000000000 --- a/example/26_contraction/run_contraction_scale_example.inc +++ /dev/null @@ -1,217 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/numeric.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" - -int run_contraction_scale_example(int argc, char* argv[]) -{ - bool do_verification = true; - int init_method = 1; - bool time_kernel = false; - - // A[M0, M1, K0, K1] - std::vector a_ms_ks_lengths{30, 128, 32, 64}; - std::vector a_ms_ks_strides{524288, 4096, 128, 1}; - // B[N0, N1, K0, K1] - std::vector b_ns_ks_lengths{32, 64, 32, 64}; - std::vector b_ns_ks_strides{524288, 4096, 128, 1}; - // E[M0, M1, N0, N1] - std::vector e_ms_ns_lengths{30, 128, 32, 64}; - std::vector e_ms_ns_strides{524288, 4096, 128, 1}; - - float scale = 1.f; - - if(argc == 1) - { - // use default case - } - else if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else if(argc == 23) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - - const ck::index_t M0 = std::stoi(argv[4]); - const ck::index_t M1 = std::stoi(argv[5]); - - const ck::index_t N0 = std::stoi(argv[6]); - const ck::index_t N1 = std::stoi(argv[7]); - - const ck::index_t K0 = std::stoi(argv[8]); - const ck::index_t K1 = std::stoi(argv[9]); - - a_ms_ks_lengths = {M0, M1, K0, K1}; - a_ms_ks_strides = { - std::stoi(argv[10]), std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13])}; - - b_ns_ks_lengths = {N0, N1, K0, K1}; - b_ns_ks_strides = { - std::stoi(argv[14]), std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17])}; - - e_ms_ns_lengths = {M0, M1, N0, N1}; - e_ms_ns_strides = { - std::stoi(argv[18]), std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21])}; - - scale = std::stof(argv[22]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=no, 1=yes)\n"); - printf("arg4 to 9: M0, M1, N0, N1, K0, K1\n"); - printf("arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n"); - printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n"); - printf("arg18 to 21: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n"); - printf("arg22: scale\n"); - exit(0); - } - - Tensor a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides); - Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); - Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); - - std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl; - std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; - std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - a_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - a_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - break; - } - - DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize()); - DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a_ms_ks.mData.data()); - b_device_buf.ToDevice(b_ns_ks.mData.data()); - - // set zero - e_device_buf.SetZero(); - - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto cde_element_op = CDEElementOp{scale}; - - // device operation - auto op = DeviceOpInstance{}; - auto invoker = op.MakeInvoker(); - auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), - b_device_buf.GetDeviceBuffer(), - std::array{}, - e_device_buf.GetDeviceBuffer(), - a_ms_ks_lengths, - a_ms_ks_strides, - b_ns_ks_lengths, - b_ns_ks_strides, - std::array, 0>{}, - std::array, 0>{}, - e_ms_ns_lengths, - e_ms_ns_strides, - a_element_op, - b_element_op, - cde_element_op); - - if(!op.IsSupportedArgument(argument)) - { - std::cout << op.GetTypeString() << " does not support this problem" << std::endl; - - return 0; - } - - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - - ck::index_t M = - ck::accumulate_n(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); - - ck::index_t N = ck::accumulate_n( - e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); - - ck::index_t K = ck::accumulate_n( - a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); - - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_btype = - sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + +sizeof(EDataType) * M * N; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << op.GetTypeString() << std::endl; - - e_device_buf.FromDevice(e_ms_ns_device_result.mData.data()); - - if(do_verification) - { - Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - - using ReferenceOpInstance = - ck::tensor_operation::host::ReferenceContraction_M2_N2_K2; - - auto ref_op = ReferenceOpInstance{}; - auto ref_invoker = ref_op.MakeInvoker(); - - auto ref_argument = - ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op); - - ref_invoker.Run(ref_argument); - - for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0) - { - for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1) - { - for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0) - { - for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1) - { - cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1), - c_ms_ns_host_result(m0, m1, n0, n1)); - } - } - } - } - - return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; - } - - return 0; -} diff --git a/example/60_gemm_multiABD/CMakeLists.txt b/example/60_gemm_multiABD/CMakeLists.txt deleted file mode 100644 index 9e2f70649f..0000000000 --- a/example/60_gemm_multiABD/CMakeLists.txt +++ /dev/null @@ -1,10 +0,0 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) -list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list2 AND target EQUAL 0) - add_example_executable(example_gemm_multiABD_xdl_fp16 gemm_multiABD_xdl_fp16.cpp) - set(target 1) - endif() -endforeach() -endif() diff --git a/example/60_gemm_multiABD/gemm_multiABD_xdl_fp16.cpp b/example/60_gemm_multiABD/gemm_multiABD_xdl_fp16.cpp deleted file mode 100644 index 23d41d7cc4..0000000000 --- a/example/60_gemm_multiABD/gemm_multiABD_xdl_fp16.cpp +++ /dev/null @@ -1,361 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/literals.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" -#include "ck/library/utility/check_err.hpp" - -template -using S = ck::Sequence; - -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -using ADataType = F16; -using BDataType = F16; -using AccDataType = F32; -using CShuffleDataType = F32; -using DDataType = F16; -using EDataType = F16; - -using ALayout = Row; -using BLayout = Col; -using DLayout = Row; -using ELayout = Row; - -struct AddScale -{ - static constexpr auto I0 = ck::Number<0>{}; - static constexpr auto I1 = ck::Number<1>{}; - static constexpr auto I2 = ck::Number<2>{}; - static constexpr auto I3 = ck::Number<3>{}; - - __host__ __device__ constexpr void - operator()(ck::half4_t& a, const ck::half4_t& a0, const ck::half4_t& a1) const - { - const auto a0_v_t = ck::vector_type{a0}; - const auto a1_v_t = ck::vector_type{a1}; - - auto r_v_t = ck::vector_type{}; - - r_v_t.AsType()(I0) = - scale * (a0_v_t.AsType()[I0] + a1_v_t.AsType()[I0]); - r_v_t.AsType()(I1) = - scale * (a0_v_t.AsType()[I1] + a1_v_t.AsType()[I1]); - r_v_t.AsType()(I2) = - scale * (a0_v_t.AsType()[I2] + a1_v_t.AsType()[I2]); - r_v_t.AsType()(I3) = - scale * (a0_v_t.AsType()[I3] + a1_v_t.AsType()[I3]); - - a = r_v_t.AsType()[I0]; - } - - __host__ __device__ constexpr void - operator()(ck::half_t& a, const ck::half_t& a0, const ck::half_t& a1) const - { - a = scale * (a0 + a1); - } - - static constexpr ck::index_t vec_len = 4; - - float scale = 1.0; -}; - -struct AlphaBetaAdd -{ - AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){}; - - template - __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const; - - template <> - __host__ __device__ constexpr void operator()( - ck::half_t& e, const float& c, const ck::half_t& d) const - { - e = ck::type_convert(alpha_ * c + beta_ * ck::type_convert(d)); - }; - - float alpha_; - float beta_; -}; - -using AElementOp = AddScale; -using BElementOp = PassThrough; -using CDEElementOp = AlphaBetaAdd; - -static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl_CShuffle< - ck::Tuple, - ck::Tuple, - ck::Tuple, - ELayout, - ck::Tuple, - ck::Tuple, - AccDataType, - CShuffleDataType, - ck::Tuple, - EDataType, - AElementOp, - BElementOp, - CDEElementOp, - GemmSpec, - 1, - 256, - 256, - 128, - 32, - 8, - 8, - 32, - 32, - 4, - 2, - S<4, 64, 1>, - S<1, 0, 2>, - S<1, 0, 2>, - 2, - 8, - 8, - 1, - S<4, 64, 1>, - S<1, 0, 2>, - S<1, 0, 2>, - 2, - 8, - 8, - 1, - 1, - 1, - S<1, 32, 1, 8>, - 8>; - -int main(int argc, char* argv[]) -{ - bool do_verification = true; - int init_method = 1; - bool time_kernel = false; - - // GEMM shape - ck::index_t M = 3840; - ck::index_t N = 4096; - ck::index_t K = 4096; - - ck::index_t StrideA = 4096; - ck::index_t StrideB = 4096; - ck::index_t StrideD = 4096; - ck::index_t StrideE = 4096; - - float alpha = 1.0f; - float beta = 1.0f; - - if(argc == 1) - { - // use default case - } - else if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else if(argc == 6) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - - alpha = std::stof(argv[4]); - beta = std::stof(argv[5]); - } - else if(argc == 13) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - - M = std::stoi(argv[4]); - N = std::stoi(argv[5]); - K = std::stoi(argv[6]); - - StrideA = std::stoi(argv[7]); - StrideB = std::stoi(argv[8]); - StrideD = std::stoi(argv[9]); - StrideE = std::stoi(argv[10]); - - alpha = std::stof(argv[11]); - beta = std::stof(argv[12]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=no, 1=yes)\n"); - printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, alpha, " - "beta\n"); - exit(0); - } - - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - using namespace ck::literals; - - if(std::is_same::value) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; - - Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); - Tensor a1_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); - Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); - Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{})); - Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); - Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); - - std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; - std::cout << "a1_m_k: " << a1_m_k.mDesc << std::endl; - std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; - std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; - std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - a0_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - a1_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - a1_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - } - - DeviceMem a0_device_buf(sizeof(ADataType) * a0_m_k.mDesc.GetElementSpaceSize()); - DeviceMem a1_device_buf(sizeof(ADataType) * a1_m_k.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); - DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize()); - DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); - - a0_device_buf.ToDevice(a0_m_k.mData.data()); - a1_device_buf.ToDevice(a1_m_k.mData.data()); - b_device_buf.ToDevice(b_k_n.mData.data()); - d_device_buf.ToDevice(d_m_n.mData.data()); - e_device_buf.ToDevice(e_m_n_device_result.mData.data()); - - auto a_element_op = AElementOp{0.2}; - auto b_element_op = BElementOp{}; - auto cde_element_op = CDEElementOp{alpha, beta}; - - // do GEMM - auto device_op = DeviceOpInstance{}; - auto invoker = device_op.MakeInvoker(); - auto argument = - device_op.MakeArgument(std::array{a0_device_buf.GetDeviceBuffer(), - a1_device_buf.GetDeviceBuffer()}, - std::array{b_device_buf.GetDeviceBuffer()}, - std::array{d_device_buf.GetDeviceBuffer()}, - e_device_buf.GetDeviceBuffer(), - M, - N, - K, - std::array{StrideA, StrideA}, - std::array{StrideB}, - std::array{StrideD}, - StrideE, - a_element_op, - b_element_op, - cde_element_op); - - if(!device_op.IsSupportedArgument(argument)) - { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); - } - - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_btype = - sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" - << std::endl; - - e_device_buf.FromDevice(e_m_n_device_result.mData.data()); - - if(do_verification) - { - Tensor c_m_n({M, N}); - - Tensor a_m_k({M, K}); - - for(int m = 0; m < M; ++m) - { - for(int k = 0; k < K; ++k) - { - a_element_op(a_m_k(m, k), a0_m_k(m, k), a1_m_k(m, k)); - } - } - - using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = - ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, b_element_op, PassThrough{}); - - ref_invoker.Run(ref_argument); - - for(int m = 0; m < M; ++m) - { - for(int n = 0; n < N; ++n) - { - cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); - } - } - - e_device_buf.FromDevice(e_m_n_device_result.mData.data()); - - return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; - } - - return 0; -} diff --git a/example/60_gemm_multi_ABD/CMakeLists.txt b/example/60_gemm_multi_ABD/CMakeLists.txt deleted file mode 100644 index 0c6ce7ae87..0000000000 --- a/example/60_gemm_multi_ABD/CMakeLists.txt +++ /dev/null @@ -1,10 +0,0 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) -list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list2 AND target EQUAL 0) - add_example_executable(example_gemm_multi_ABD_xdl_fp16 gemm_multi_ABD_xdl_fp16.cpp) - set(target 1) - endif() -endforeach() -endif() diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp deleted file mode 100644 index 90558364a4..0000000000 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp +++ /dev/null @@ -1,362 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/literals.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" -#include "ck/library/utility/check_err.hpp" - -template -using S = ck::Sequence; - -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -using ADataType = F16; -using BDataType = F16; -using AccDataType = F32; -using CShuffleDataType = F32; -using DDataType = F16; -using EDataType = F16; - -using ALayout = Row; -using BLayout = Col; -using DLayout = Row; -using ELayout = Row; - -struct AddScale -{ - static constexpr auto I0 = ck::Number<0>{}; - static constexpr auto I1 = ck::Number<1>{}; - static constexpr auto I2 = ck::Number<2>{}; - static constexpr auto I3 = ck::Number<3>{}; - - __host__ __device__ constexpr void - operator()(ck::half4_t& a, const ck::half4_t& a0, const ck::half4_t& a1) const - { - const auto a0_v_t = ck::vector_type{a0}; - const auto a1_v_t = ck::vector_type{a1}; - - auto r_v_t = ck::vector_type{}; - - r_v_t.AsType()(I0) = - scale * (a0_v_t.AsType()[I0] + a1_v_t.AsType()[I0]); - r_v_t.AsType()(I1) = - scale * (a0_v_t.AsType()[I1] + a1_v_t.AsType()[I1]); - r_v_t.AsType()(I2) = - scale * (a0_v_t.AsType()[I2] + a1_v_t.AsType()[I2]); - r_v_t.AsType()(I3) = - scale * (a0_v_t.AsType()[I3] + a1_v_t.AsType()[I3]); - - a = r_v_t.AsType()[I0]; - } - - __host__ __device__ constexpr void - operator()(ck::half_t& a, const ck::half_t& a0, const ck::half_t& a1) const - { - a = scale * (a0 + a1); - } - - // this attribute will force copy_function applying element_wise with vector_type - static constexpr ck::index_t vec_len = 4; - - float scale = 1.0; -}; - -struct AlphaBetaAdd -{ - AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){}; - - template - __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const; - - template <> - __host__ __device__ constexpr void operator()( - ck::half_t& e, const float& c, const ck::half_t& d) const - { - e = ck::type_convert(alpha_ * c + beta_ * ck::type_convert(d)); - }; - - float alpha_; - float beta_; -}; - -using AElementOp = AddScale; -using BElementOp = PassThrough; -using CDEElementOp = AlphaBetaAdd; - -static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl_CShuffle< - ck::Tuple, - ck::Tuple, - ck::Tuple, - ELayout, - ck::Tuple, - ck::Tuple, - AccDataType, - CShuffleDataType, - ck::Tuple, - EDataType, - AElementOp, - BElementOp, - CDEElementOp, - GemmSpec, - 1, - 256, - 256, - 128, - 32, - 8, - 8, - 32, - 32, - 4, - 2, - S<4, 64, 1>, - S<1, 0, 2>, - S<1, 0, 2>, - 2, - 8, - 8, - 1, - S<4, 64, 1>, - S<1, 0, 2>, - S<1, 0, 2>, - 2, - 8, - 8, - 1, - 1, - 1, - S<1, 32, 1, 8>, - 8>; - -int main(int argc, char* argv[]) -{ - bool do_verification = true; - int init_method = 1; - bool time_kernel = false; - - // GEMM shape - ck::index_t M = 3840; - ck::index_t N = 4096; - ck::index_t K = 4096; - - ck::index_t StrideA = 4096; - ck::index_t StrideB = 4096; - ck::index_t StrideD = 4096; - ck::index_t StrideE = 4096; - - float alpha = 1.0f; - float beta = 1.0f; - - if(argc == 1) - { - // use default case - } - else if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else if(argc == 6) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - - alpha = std::stof(argv[4]); - beta = std::stof(argv[5]); - } - else if(argc == 13) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - - M = std::stoi(argv[4]); - N = std::stoi(argv[5]); - K = std::stoi(argv[6]); - - StrideA = std::stoi(argv[7]); - StrideB = std::stoi(argv[8]); - StrideD = std::stoi(argv[9]); - StrideE = std::stoi(argv[10]); - - alpha = std::stof(argv[11]); - beta = std::stof(argv[12]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=no, 1=yes)\n"); - printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, alpha, " - "beta\n"); - exit(0); - } - - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - using namespace ck::literals; - - if(std::is_same::value) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; - - Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); - Tensor a1_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); - Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); - Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{})); - Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); - Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); - - std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; - std::cout << "a1_m_k: " << a1_m_k.mDesc << std::endl; - std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; - std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; - std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - a0_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - a1_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - a1_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - } - - DeviceMem a0_device_buf(sizeof(ADataType) * a0_m_k.mDesc.GetElementSpaceSize()); - DeviceMem a1_device_buf(sizeof(ADataType) * a1_m_k.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); - DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize()); - DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); - - a0_device_buf.ToDevice(a0_m_k.mData.data()); - a1_device_buf.ToDevice(a1_m_k.mData.data()); - b_device_buf.ToDevice(b_k_n.mData.data()); - d_device_buf.ToDevice(d_m_n.mData.data()); - e_device_buf.ToDevice(e_m_n_device_result.mData.data()); - - auto a_element_op = AElementOp{0.2}; - auto b_element_op = BElementOp{}; - auto cde_element_op = CDEElementOp{alpha, beta}; - - // do GEMM - auto device_op = DeviceOpInstance{}; - auto invoker = device_op.MakeInvoker(); - auto argument = - device_op.MakeArgument(std::array{a0_device_buf.GetDeviceBuffer(), - a1_device_buf.GetDeviceBuffer()}, - std::array{b_device_buf.GetDeviceBuffer()}, - std::array{d_device_buf.GetDeviceBuffer()}, - e_device_buf.GetDeviceBuffer(), - M, - N, - K, - std::array{StrideA, StrideA}, - std::array{StrideB}, - std::array{StrideD}, - StrideE, - a_element_op, - b_element_op, - cde_element_op); - - if(!device_op.IsSupportedArgument(argument)) - { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); - } - - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_btype = - sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" - << std::endl; - - e_device_buf.FromDevice(e_m_n_device_result.mData.data()); - - if(do_verification) - { - Tensor c_m_n({M, N}); - - Tensor a_m_k({M, K}); - - for(int m = 0; m < M; ++m) - { - for(int k = 0; k < K; ++k) - { - a_element_op(a_m_k(m, k), a0_m_k(m, k), a1_m_k(m, k)); - } - } - - using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = - ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, b_element_op, PassThrough{}); - - ref_invoker.Run(ref_argument); - - for(int m = 0; m < M; ++m) - { - for(int n = 0; n < N; ++n) - { - cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); - } - } - - e_device_buf.FromDevice(e_m_n_device_result.mData.data()); - - return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; - } - - return 0; -} diff --git a/example/61_contraction_multi_ABD/CMakeLists.txt b/example/61_contraction_multi_ABD/CMakeLists.txt deleted file mode 100644 index 0420fdb9c3..0000000000 --- a/example/61_contraction_multi_ABD/CMakeLists.txt +++ /dev/null @@ -1,10 +0,0 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) -list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list2 AND target EQUAL 0) - add_example_executable(example_contraction_multi_ABD_xdl_fp16 contraction_multi_ABD_xdl_fp16.cpp) - set(target 1) - endif() -endforeach() -endif() diff --git a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp deleted file mode 100644 index 6634dacd25..0000000000 --- a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp +++ /dev/null @@ -1,328 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. - -#include -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/literals.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/numeric.hpp" - -template -using S = ck::Sequence; - -using F16 = ck::half_t; -using F32 = float; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -using A0DataType = F16; -using A1DataType = F32; -using BDataType = F16; -using AccDataType = F32; -using CShuffleDataType = F32; -using DDataType = F16; -using EDataType = F16; - -static constexpr ck::index_t NumDimM = 2; -static constexpr ck::index_t NumDimN = 2; -static constexpr ck::index_t NumDimK = 2; - -struct AlphaBetaAdd -{ - AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){}; - - template - __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const; - - template <> - __host__ __device__ constexpr void operator()( - ck::half_t& e, const float& c, const ck::half_t& d) const - { - e = ck::type_convert(alpha_ * c + beta_ * ck::type_convert(d)); - }; - - float alpha_; - float beta_; -}; - -struct Multiply -{ - __host__ __device__ constexpr void - operator()(ck::half_t& a, const ck::half_t& a0, const float& a1) const - { - a = ck::type_convert(ck::type_convert(a0) * a1); - } -}; - -using AElementOp = Multiply; -using BElementOp = PassThrough; -using CDEElementOp = AlphaBetaAdd; - -static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -using DeviceOpInstance = ck::tensor_operation::device::DeviceContractionMultipleABD_Xdl_CShuffle< - NumDimM, - NumDimN, - NumDimK, - ck::Tuple, - ck::Tuple, - AccDataType, - CShuffleDataType, - ck::Tuple, - EDataType, - AElementOp, - BElementOp, - CDEElementOp, - GemmSpec, - 1, - 256, - 256, - 128, - 32, - 8, - 8, - 32, - 32, - 4, - 2, - S<4, 64, 1>, - S<1, 0, 2>, - S<1, 0, 2>, - 2, - 8, - 8, - 1, - S<4, 64, 1>, - S<1, 0, 2>, - S<1, 0, 2>, - 2, - 8, - 8, - 1, - 1, - 1, - S<1, 32, 1, 8>, - 8>; - -int main(int argc, char* argv[]) -{ - bool do_verification = true; - int init_method = 1; - bool time_kernel = false; - - float alpha = 1.0f; - float beta = 1.0f; - - // A0[M0, M1, K0, K1] - std::vector a0_ms_ks_lengths{30, 128, 32, 64}; - std::vector a0_ms_ks_strides{128 * 32 * 64, 32 * 64, 64, 1}; - // A1[M1, K1] -> A1[M0, M1, K0, K1] - std::vector a1_ms_ks_lengths{30, 128, 32, 64}; - std::vector a1_ms_ks_strides{0, 64, 0, 1}; - // B[N0, N1, K0, K1] - std::vector b_ns_ks_lengths{32, 64, 32, 64}; - std::vector b_ns_ks_strides{64 * 32 * 64, 32 * 64, 64, 1}; - // D[M0, M1, N0, N1] - std::vector d_ms_ns_lengths{30, 128, 32, 64}; - std::vector d_ms_ns_strides{128 * 32 * 64, 32 * 64, 64, 1}; - // E[M0, M1, N0, N1] - std::vector e_ms_ns_lengths{30, 128, 32, 64}; - std::vector e_ms_ns_strides{128 * 32 * 64, 32 * 64, 64, 1}; - - if(argc == 1) - { - // use default case - } - else if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=no, 1=yes)\n"); - exit(0); - } - - Tensor a0_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides); - Tensor a1_ms_ks(a1_ms_ks_lengths, a1_ms_ks_strides); - Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); - Tensor d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides); - Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); - - std::cout << "a0_ms_ks: " << a0_ms_ks.mDesc << std::endl; - std::cout << "a1_ms_ks: " << a1_ms_ks.mDesc << std::endl; - std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; - std::cout << "d_ms_ns: " << d_ms_ns.mDesc << std::endl; - std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - a0_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - a1_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - d_ms_ns.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - a0_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - a1_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - d_ms_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - break; - } - - DeviceMem a0_device_buf(sizeof(A0DataType) * a0_ms_ks.mDesc.GetElementSpaceSize()); - DeviceMem a1_device_buf(sizeof(A1DataType) * a1_ms_ks.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize()); - DeviceMem d_device_buf(sizeof(DDataType) * d_ms_ns.mDesc.GetElementSpaceSize()); - DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize()); - - a0_device_buf.ToDevice(a0_ms_ks.mData.data()); - a1_device_buf.ToDevice(a1_ms_ks.mData.data()); - b_device_buf.ToDevice(b_ns_ks.mData.data()); - d_device_buf.ToDevice(d_ms_ns.mData.data()); - - // set zero - e_device_buf.SetZero(); - - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto cde_element_op = CDEElementOp{alpha, beta}; - - // do GEMM - auto device_op = DeviceOpInstance{}; - auto invoker = device_op.MakeInvoker(); - auto argument = device_op.MakeArgument( - std::array{a0_device_buf.GetDeviceBuffer(), - a1_device_buf.GetDeviceBuffer()}, - std::array{b_device_buf.GetDeviceBuffer()}, - std::array{d_device_buf.GetDeviceBuffer()}, - e_device_buf.GetDeviceBuffer(), - std::array, 2>{a0_ms_ks_lengths, a1_ms_ks_lengths}, - std::array, 2>{a0_ms_ks_strides, a1_ms_ks_strides}, - std::array, 1>{b_ns_ks_lengths}, - std::array, 1>{b_ns_ks_strides}, - std::array, 1>{d_ms_ns_lengths}, - std::array, 1>{d_ms_ns_strides}, - e_ms_ns_lengths, - e_ms_ns_strides, - a_element_op, - b_element_op, - cde_element_op); - - if(!device_op.IsSupportedArgument(argument)) - { - throw std::runtime_error( - "wrong! device_contraction with the specified compilation parameters does " - "not support this problem"); - } - - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - - if(time_kernel) - { - ck::index_t M = - ck::accumulate_n(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); - - ck::index_t N = ck::accumulate_n( - e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); - - ck::index_t K = ck::accumulate_n( - a0_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); - - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_btype = - sizeof(A0DataType) * M * K + sizeof(BDataType) * K * N + +sizeof(EDataType) * M * N; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec - << " GB/s" << std::endl; - } - - if(do_verification) - { - - Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - - Tensor a_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides); - - for(size_t m0 = 0; m0 < a_ms_ks.mDesc.GetLengths()[0]; ++m0) - { - for(size_t m1 = 0; m1 < a_ms_ks.mDesc.GetLengths()[1]; ++m1) - { - for(size_t k0 = 0; k0 < a_ms_ks.mDesc.GetLengths()[2]; ++k0) - { - for(size_t k1 = 0; k1 < a_ms_ks.mDesc.GetLengths()[3]; ++k1) - { - a_element_op(a_ms_ks(m0, m1, k0, k1), - a0_ms_ks(m0, m1, k0, k1), - a1_ms_ks(m0, m1, k0, k1)); - } - } - } - } - - using ReferenceOpInstance = - ck::tensor_operation::host::ReferenceContraction_M2_N2_K2; - - auto ref_op = ReferenceOpInstance{}; - auto ref_invoker = ref_op.MakeInvoker(); - - Tensor empty_tensor(std::vector{}, std::vector{}); - auto ref_argument = - ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, PassThrough{}, b_element_op); - - ref_invoker.Run(ref_argument); - - for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0) - { - for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1) - { - for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0) - { - for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1) - { - cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1), - c_ms_ns_host_result(m0, m1, n0, n1), - d_ms_ns(m0, m1, n0, n1)); - } - } - } - } - - e_device_buf.FromDevice(e_ms_ns_device_result.mData.data()); - - return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; - } - - return 0; -} diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index d030c53453..7f8704f281 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -30,7 +30,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) set(test 0) break() elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR - source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND + source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND NOT(source MATCHES type OR source MATCHES type1)) #if filename contains a type which doesn't match any selected type, mark it for removal set(test 1) @@ -59,7 +59,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) set(result 0) endif() #message("add_example returns ${result}") - return(PROPAGATE result) + set(result ${result} PARENT_SCOPE) endfunction(add_example_executable EXAMPLE_NAME) function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) @@ -87,7 +87,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) set(test 0) break() elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR - source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND + source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND NOT(source MATCHES type OR source MATCHES type1)) #if filename contains a type which doesn't match any selected type, mark it for removal set(test 1) @@ -96,7 +96,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) if(test EQUAL 1) message("removing example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") - endif() + endif() endforeach() endif() foreach(source IN LISTS FILE_NAME) @@ -114,7 +114,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) set(result 0) endif() #message("add_example returns ${result}") - return(PROPAGATE result) + set(result ${result} PARENT_SCOPE) endfunction(add_example_executable_no_testing EXAMPLE_NAME) # add all example subdir diff --git a/include/ck/tensor_operation/gpu/device/device_contraction_multiple_abd.hpp b/include/ck/tensor_operation/gpu/device/device_contraction_multiple_abd.hpp deleted file mode 100644 index 3116d40474..0000000000 --- a/include/ck/tensor_operation/gpu/device/device_contraction_multiple_abd.hpp +++ /dev/null @@ -1,61 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include - -#include "ck/tensor_operation/gpu/device/device_base.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { - -// GEMM: -// input : A0[M0, M1, ... K0, K1, ...], ... -// input : B0[N0, N1, ... K0, K1, ...], ... -// input : D0[M0, M1, ... N0, N1, ...], D1[M0, M1, ... N0, N1, ...], ... -// output : E[M0, M1, ... N0, N1, ...] -// C = a_op(A) * b_op(B) -// E = cde_op(C, D0, D1, ...) -// Assume: -// D0, D1, ... and E have the same layout -template -struct DeviceContractionMultipleABD : public BaseOperator -{ - static constexpr index_t NumATensor = AsDataType::Size(); - static constexpr index_t NumBTensor = BsDataType::Size(); - static constexpr index_t NumDTensor = DsDataType::Size(); - - virtual std::unique_ptr - MakeArgumentPointer(std::array p_as, - std::array p_bs, - std::array p_ds, - void* p_e, - const std::array, NumATensor>& a_ms_ks_lengths, - const std::array, NumATensor>& a_ms_ks_strides, - const std::array, NumBTensor>& b_ns_ks_lengths, - const std::array, NumBTensor>& b_ns_ks_strides, - const std::array, NumDTensor>& d_ms_ns_lengths, - const std::array, NumDTensor>& d_ms_ns_strides, - const std::vector& e_ms_ns_length, - const std::vector& e_ms_ns_stride, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) = 0; - - virtual std::unique_ptr MakeInvokerPointer() = 0; -}; - -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp index dffb51d0ba..118ade8978 100644 --- a/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp @@ -33,8 +33,7 @@ template + typename CDEElementwiseOperation> struct DeviceContractionMultipleD : public BaseOperator { static constexpr index_t NumDTensor = DsDataType::Size(); diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp index 7e4bca2bd6..2abf1d5a10 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp @@ -29,7 +29,9 @@ template + typename CDEElementwiseOperation, + typename AComputeType = ADataType, + typename BComputeType = AComputeType> struct DeviceGroupedConvBwdDataMultipleD : public BaseOperator { static constexpr index_t NumDTensor = DsDataType::Size(); diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp index ab9e6adb41..7296e4faaa 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp @@ -20,7 +20,9 @@ template + typename OutElementwiseOperation, + typename ComputeTypeA = InDataType, + typename ComputeTypeB = ComputeTypeA> struct DeviceGroupedConvBwdWeight : public BaseOperator { virtual std::unique_ptr diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp index 1cc30fd9e6..2ca82dc6da 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp @@ -29,7 +29,8 @@ template + typename CDEElementwiseOperation, + typename ComputeType = ADataType> struct DeviceGroupedConvFwdMultipleD : public BaseOperator { static constexpr index_t NumDTensor = DsDataType::Size(); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp deleted file mode 100644 index a4c6767434..0000000000 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp +++ /dev/null @@ -1,846 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include - -#include "ck/utility/common_header.hpp" -#include "ck/tensor_description/tensor_descriptor.hpp" -#include "ck/tensor_description/tensor_descriptor_helper.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_abd.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/host_utility/kernel_launch.hpp" - -namespace ck { - -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_contraction_multiple_abd_xdl_cshuffle( - AsPointer p_as_grid, - BsPointer p_bs_grid, - DsPointer p_ds_grid, - EDataType* __restrict__ p_e_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1, - const BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1, - const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock, - const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock, - const Block2ETileMap block_2_etile_map) -{ -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - - GridwiseGemm::template Run(p_as_grid, - p_bs_grid, - p_ds_grid, - p_e_grid, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - as_grid_desc_ak0_m_ak1, - bs_grid_desc_bk0_n_bk1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_etile_map); -#else - ignore = p_as_grid; - ignore = p_bs_grid; - ignore = p_ds_grid; - ignore = p_e_grid; - ignore = a_element_op; - ignore = b_element_op; - ignore = cde_element_op; - ignore = as_grid_desc_ak0_m_ak1; - ignore = bs_grid_desc_bk0_n_bk1; - ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; - ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; - ignore = block_2_etile_map; -#endif -} - -} // namespace ck - -namespace ck { -namespace tensor_operation { -namespace device { - -// GEMM: -// input : A[M, K] -// input : B[N, K] -// input : D0[M, N], D1[M, N], ... -// output : E[M, N] -// C = a_op(A) * b_op(B) -// E = cde_op(C, D0, D1, ...) -// Assume: -// D0, D1, ... and E have the same layout -template -struct DeviceContractionMultipleABD_Xdl_CShuffle - : public DeviceContractionMultipleABD -{ - using DeviceOp = DeviceContractionMultipleABD_Xdl_CShuffle; - - static constexpr index_t NumATensor = AsDataType::Size(); - static constexpr index_t NumBTensor = BsDataType::Size(); - static constexpr index_t NumDTensor = DsDataType::Size(); - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - - using ComputeDataType = EDataType; - - // GridwiseGemm - using GridwiseGemm = GridwiseGemmMultipleABD_xdl_cshuffle< - AsDataType, - BsDataType, - ComputeDataType, - AccDataType, - CShuffleDataType, - DsDataType, - EDataType, - AElementwiseOperation, - BElementwiseOperation, - CDEElementwiseOperation, - InMemoryDataOperationEnum::Set, - NumGemmKPrefetchStage, - BlockSize, - MPerBlock, - NPerBlock, - KPerBlock, - AK1, - BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - false, - ABlockLdsExtraM, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - false, - BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, - CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - CDEBlockTransferScalarPerVector_NPerBlock, - LoopSched, - PipelineVer>; - - static constexpr auto matrix_padder = - ck::tensor_operation::device::MatrixPadder{ - MPerBlock, NPerBlock, KPerBlock}; - - static auto MakeAGridDescriptor_M_K(const std::vector& a_ms_ks_lengths_, - const std::vector& a_ms_ks_strides_) - { - assert(a_ms_ks_lengths_.size() == NumDimM + NumDimK && - a_ms_ks_strides_.size() == NumDimM + NumDimK); - - const auto to_tuple = [&](auto& vec, auto num) { - return generate_tuple([&](auto i) { return vec[i]; }, num); - }; - - const auto a_ms_ks_lengths = to_tuple(a_ms_ks_lengths_, Number{}); - const auto a_ms_ks_strides = to_tuple(a_ms_ks_strides_, Number{}); - - // dimension Ids for M0, M1, ... - constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; - - // dimension Ids for K0, K1, ... - constexpr auto kDimIds = - typename arithmetic_sequence_gen::type{}; - - // lengths for M0, M1, ... - const auto mLengths = get_container_subset(a_ms_ks_lengths, mDimIds); - - // lengths for K0, K1, ... - const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds); - - // naive tensor A[M0, M1, M2, ..., K0, K1, K2...] - const auto a_grid_desc_ms_ks = - make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides); - - // transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...] - const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor( - a_grid_desc_ms_ks, - make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)), - make_tuple(mDimIds, kDimIds), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); - } - - __host__ __device__ static auto - MakeAsGridDescriptor_M_K(const std::array, NumATensor>& as_ms_ks_lengths, - const std::array, NumATensor>& as_ms_ks_strides) - { - return generate_tuple( - [&](auto i) { - return MakeAGridDescriptor_M_K(as_ms_ks_lengths[i], as_ms_ks_strides[i]); - }, - Number{}); - } - - // Assume: B[N0, N1, N2, ..., K0, K1, K2, ...] - static auto MakeBGridDescriptor_N_K(const std::vector& b_ns_ks_lengths_, - const std::vector& b_ns_ks_strides_) - { - assert(b_ns_ks_lengths_.size() == NumDimN + NumDimK && - b_ns_ks_strides_.size() == NumDimN + NumDimK); - - const auto to_tuple = [&](auto& vec, auto num) { - return generate_tuple([&](auto i) { return vec[i]; }, num); - }; - - const auto b_ns_ks_lengths = to_tuple(b_ns_ks_lengths_, Number{}); - const auto b_ns_ks_strides = to_tuple(b_ns_ks_strides_, Number{}); - - // dimension Ids for N0, N1, ... - constexpr auto nDimIds = typename arithmetic_sequence_gen<0, NumDimN, 1>::type{}; - - // dimension Ids for K0, K1, ... - constexpr auto kDimIds = - typename arithmetic_sequence_gen::type{}; - - // lengths for K0, K1, ... - const auto kLengths = get_container_subset(b_ns_ks_lengths, kDimIds); - - // lengths for N0, N1, ... - const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds); - - // naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...] - const auto b_grid_desc_ns_ks = - make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides); - - // transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...] - const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor( - b_grid_desc_ns_ks, - make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)), - make_tuple(nDimIds, kDimIds), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); - } - - __host__ __device__ static auto - MakeBsGridDescriptor_N_K(const std::array, NumBTensor>& bs_ns_ks_lengths, - const std::array, NumBTensor>& bs_ns_ks_strides) - { - return generate_tuple( - [&](auto i) { - return MakeBGridDescriptor_N_K(bs_ns_ks_lengths[i], bs_ns_ks_strides[i]); - }, - Number{}); - } - - // assume E[M0, M1, M2, ..., N0, N1, N2...] - static auto MakeEGridDescriptor_M_N(const std::vector& e_ms_ns_lengths_, - const std::vector& e_ms_ns_strides_) - { - assert(e_ms_ns_lengths_.size() == NumDimM + NumDimN && - e_ms_ns_strides_.size() == NumDimM + NumDimN); - - const auto to_tuple = [&](auto& vec, auto num) { - return generate_tuple([&](auto i) { return vec[i]; }, num); - }; - - const auto e_ms_ns_lengths = to_tuple(e_ms_ns_lengths_, Number{}); - const auto e_ms_ns_strides = to_tuple(e_ms_ns_strides_, Number{}); - - // dimension Ids for M0, M1, ... - constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; - - // dimension Ids for N0, N1, ... - constexpr auto nDimIds = - typename arithmetic_sequence_gen::type{}; - - // lengths for M0, M1, ... - const auto mLengths = get_container_subset(e_ms_ns_lengths, mDimIds); - - // lengths for K0, K1, ... - const auto nLengths = get_container_subset(e_ms_ns_lengths, nDimIds); - - // naive tensor E[M0, M1, M2, ..., N0, N1, N2...] - const auto e_grid_desc_ms_ns = - make_naive_tensor_descriptor(e_ms_ns_lengths, e_ms_ns_strides); - - // transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...] - const auto e_grid_desc_mraw_nraw = transform_tensor_descriptor( - e_grid_desc_ms_ns, - make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)), - make_tuple(mDimIds, nDimIds), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); - } - - static auto - MakeDsGridDescriptor_M_N(const std::array, NumDTensor>& ds_ms_ns_lengths, - const std::array, NumDTensor>& ds_ms_ns_strides) - { - return generate_tuple( - [&](auto i) { - return MakeEGridDescriptor_M_N(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]); - }, - Number{}); - } - - // desc for problem definition - using AsGridDesc_M_K = remove_cvref_t; - using BsGridDesc_N_K = remove_cvref_t; - using DsGridDesc_M_N = remove_cvref_t; - using EGridDesc_M_N = remove_cvref_t; - - // desc for blockwise copy - using AsGridDesc_AK0_M_AK1 = - remove_cvref_t; - using BsGridDesc_BK0_N_BK1 = - remove_cvref_t; - using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< - decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - DsGridDesc_M_N{}))>; - using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - remove_cvref_t; - - // block-to-e-tile map - using Block2ETileMap = - remove_cvref_t; - - // Argument - struct Argument : public BaseArgument - { - Argument(std::array p_as_grid, - std::array p_bs_grid, - std::array p_ds_grid, - void* p_e_grid, - const std::array, NumATensor>& a_ms_ks_lengths, - const std::array, NumATensor>& a_ms_ks_strides, - const std::array, NumBTensor>& b_ns_ks_lengths, - const std::array, NumBTensor>& b_ns_ks_strides, - const std::array, NumDTensor>& d_ms_ns_lengths, - const std::array, NumDTensor>& d_ms_ns_strides, - const std::vector& e_ms_ns_length, - const std::vector& e_ms_ns_stride, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - : p_as_grid_{}, - p_bs_grid_{}, - p_ds_grid_{}, - p_e_grid_{static_cast(p_e_grid)}, - as_grid_desc_m_k_{}, - bs_grid_desc_n_k_{}, - ds_grid_desc_m_n_{}, - e_grid_desc_m_n_{MakeEGridDescriptor_M_N(e_ms_ns_length, e_ms_ns_stride)}, - as_grid_desc_ak0_m_ak1_{}, - bs_grid_desc_bk0_n_bk1_{}, - ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, - e_grid_desc_mblock_mperblock_nblock_nperblock_{}, - block_2_etile_map_{GridwiseGemm::MakeBlock2ETileMap(e_grid_desc_m_n_)}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - cde_element_op_{cde_element_op} - { - // populate pointer, desc for As - static_for<0, NumATensor, 1>{}([&](auto i) { - // using ALayout = remove_cvref_t>; - using ADataType = remove_cvref_t>; - - // A pointer - p_as_grid_(i) = static_cast(p_as_grid[i]); - - // A desc - as_grid_desc_m_k_(i) = - MakeAGridDescriptor_M_K(a_ms_ks_lengths[i], a_ms_ks_strides[i]); - }); - - // populate pointer, desc for Bs - static_for<0, NumBTensor, 1>{}([&](auto i) { - // using BLayout = remove_cvref_t>; - using BDataType = remove_cvref_t>; - - // B pointer - p_bs_grid_(i) = static_cast(p_bs_grid[i]); - - // B desc - bs_grid_desc_n_k_(i) = - MakeBGridDescriptor_N_K(b_ns_ks_lengths[i], b_ns_ks_strides[i]); - }); - - // populate pointer, desc for Ds - static_for<0, NumDTensor, 1>{}([&](auto i) { - // using DLayout = remove_cvref_t>; - using DDataType = remove_cvref_t>; - - // D pointer - p_ds_grid_(i) = static_cast(p_ds_grid[i]); - - // D desc - ds_grid_desc_m_n_(i) = - MakeEGridDescriptor_M_N(d_ms_ns_lengths[i], d_ms_ns_strides[i]); - }); - - // populate desc for Ds/E - if(GridwiseGemm::CheckValidity(as_grid_desc_m_k_, - bs_grid_desc_n_k_, - ds_grid_desc_m_n_, - e_grid_desc_m_n_, - block_2_etile_map_)) - { - as_grid_desc_ak0_m_ak1_ = - GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k_); - - bs_grid_desc_bk0_n_bk1_ = - GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k_); - - ds_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n_); - - e_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n_); - } - - // for sanity check of vector memory access - for(index_t i = 0; i < NumATensor; ++i) - { - a_mz_stride_[i] = a_ms_ks_strides[i][NumDimM - 1]; - a_kz_stride_[i] = a_ms_ks_strides[i][NumDimM + NumDimK - 1]; - } - - for(index_t i = 0; i < NumBTensor; ++i) - { - b_nz_stride_[i] = b_ns_ks_strides[i][NumDimN - 1]; - b_kz_stride_[i] = b_ns_ks_strides[i][NumDimN + NumDimK - 1]; - } - - for(index_t i = 0; i < NumDTensor; ++i) - { - ds_nz_stride_[i] = d_ms_ns_strides[i][NumDimM + NumDimN - 1]; - } - - e_nz_stride_ = e_ms_ns_stride[NumDimM + NumDimN - 1]; - } - - // pointers - typename GridwiseGemm::AsGridPointer p_as_grid_; - typename GridwiseGemm::BsGridPointer p_bs_grid_; - typename GridwiseGemm::DsGridPointer p_ds_grid_; - EDataType* p_e_grid_; - - // tensor descriptors for problem definiton - AsGridDesc_M_K as_grid_desc_m_k_; - BsGridDesc_N_K bs_grid_desc_n_k_; - DsGridDesc_M_N ds_grid_desc_m_n_; - EGridDesc_M_N e_grid_desc_m_n_; - - // tensor descriptors for block/thread-wise copy - AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1_; - BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1_; - DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock_; - EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; - - // block-to-e-tile map - Block2ETileMap block_2_etile_map_; - - // element-wise op - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CDEElementwiseOperation cde_element_op_; - - // Strides for the last M/N/K dimensions of A/B/Ds/E - // for sanity check of vector load/store - std::array a_mz_stride_; - std::array a_kz_stride_; - - std::array b_nz_stride_; - std::array b_kz_stride_; - - std::array ds_nz_stride_; - index_t e_nz_stride_; - }; - - // Invoker - struct Invoker : public BaseInvoker - { - using Argument = DeviceOp::Argument; - - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) - { - if(!GridwiseGemm::CheckValidity(arg.as_grid_desc_m_k_, - arg.bs_grid_desc_n_k_, - arg.ds_grid_desc_m_n_, - arg.e_grid_desc_m_n_, - arg.block_2_etile_map_)) - { - throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); - } - - const index_t grid_size = - arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); - - auto launch_kernel = [&](auto has_main_k_block_loop) { - constexpr bool has_main_loop = has_main_k_block_loop.value; - - const auto kernel = kernel_contraction_multiple_abd_xdl_cshuffle< - GridwiseGemm, - typename GridwiseGemm::AsGridPointer, - typename GridwiseGemm::BsGridPointer, - typename GridwiseGemm::DsGridPointer, - EDataType, - AElementwiseOperation, - BElementwiseOperation, - CDEElementwiseOperation, - DeviceOp::AsGridDesc_AK0_M_AK1, - DeviceOp::BsGridDesc_BK0_N_BK1, - DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - DeviceOp::Block2ETileMap, - has_main_loop>; - - return launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_as_grid_, - arg.p_bs_grid_, - arg.p_ds_grid_, - arg.p_e_grid_, - arg.a_element_op_, - arg.b_element_op_, - arg.cde_element_op_, - arg.as_grid_desc_ak0_m_ak1_, - arg.bs_grid_desc_bk0_n_bk1_, - arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.block_2_etile_map_); - }; - - const auto K = arg.as_grid_desc_m_k_[I0].GetLength(I1); - - if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) - { - return launch_kernel(integral_constant{}); - } - else - { - return launch_kernel(integral_constant{}); - } - } - - // polymorphic - float Run(const BaseArgument* p_arg, - const StreamConfig& stream_config = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg), stream_config); - } - }; - - static bool IsSupportedArgument(const Argument& arg) - { - if(!ck::is_xdl_supported()) - { - return false; - } - - // check vector load/store - { - bool all_valid = true; - - static_for<0, NumATensor, 1>{}([&](auto i) { - // vector memory access of A: could be on M or AK1 dimension - if constexpr(ABlockTransferSrcVectorDim == 1) - { - if(!(arg.a_mz_stride_[i] == 1 && arg.as_grid_desc_ak0_m_ak1_[i].GetLength(I1) % - ABlockTransferSrcScalarPerVector == - 0)) - { - all_valid = false; - } - } - else - { - if(!(arg.a_kz_stride_[i] == 1 && arg.as_grid_desc_ak0_m_ak1_[i].GetLength(I2) % - ABlockTransferSrcScalarPerVector == - 0)) - { - all_valid = false; - } - } - }); - - // vector memory access of B: could be on N or BK1 dimension - static_for<0, NumBTensor, 1>{}([&](auto i) { - if constexpr(BBlockTransferSrcVectorDim == 1) - { - if(!(arg.b_nz_stride_[i] == 1 && arg.bs_grid_desc_bk0_n_bk1_[i].GetLength(I1) % - BBlockTransferSrcScalarPerVector == - 0)) - { - all_valid = false; - } - } - else - { - if(!(arg.b_kz_stride_[i] == 1 && arg.bs_grid_desc_bk0_n_bk1_[i].GetLength(I2) % - BBlockTransferSrcScalarPerVector == - 0)) - { - all_valid = false; - } - } - }); - - // check vector load of Ds - static_for<0, NumDTensor, 1>{}([&](auto i) { - if(!(arg.ds_nz_stride_[i] == 1 && - arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[i].GetLength(I3) % - CDEBlockTransferScalarPerVector_NPerBlock == - 0)) - { - all_valid = false; - } - }); - - // vector memory access of E: always on NPerBlock dimension - if(!(arg.e_nz_stride_ == 1 && - arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I3) % - CDEBlockTransferScalarPerVector_NPerBlock == - 0)) - { - all_valid = false; - } - - if(!all_valid) - { - return false; - } - } - - return GridwiseGemm::CheckValidity(arg.as_grid_desc_m_k_, - arg.bs_grid_desc_n_k_, - arg.ds_grid_desc_m_n_, - arg.e_grid_desc_m_n_, - arg.block_2_etile_map_); - } - - // polymorphic - bool IsSupportedArgument(const BaseArgument* p_arg) override - { - return IsSupportedArgument(*dynamic_cast(p_arg)); - } - - static auto MakeArgument(std::array p_as, - std::array p_bs, - std::array p_ds, - void* p_e, - const std::array, NumATensor>& a_ms_ks_lengths, - const std::array, NumATensor>& a_ms_ks_strides, - const std::array, NumBTensor>& b_ns_ks_lengths, - const std::array, NumBTensor>& b_ns_ks_strides, - const std::array, NumDTensor>& d_ms_ns_lengths, - const std::array, NumDTensor>& d_ms_ns_strides, - const std::vector& e_ms_ns_length, - const std::vector& e_ms_ns_stride, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - { - return Argument{p_as, - p_bs, - p_ds, - p_e, - a_ms_ks_lengths, - a_ms_ks_strides, - b_ns_ks_lengths, - b_ns_ks_strides, - d_ms_ns_lengths, - d_ms_ns_strides, - e_ms_ns_length, - e_ms_ns_stride, - a_element_op, - b_element_op, - cde_element_op}; - } - - static auto MakeInvoker() { return Invoker{}; } - - // polymorphic - std::unique_ptr - MakeArgumentPointer(std::array p_as, - std::array p_bs, - std::array p_ds, - void* p_e, - const std::array, NumATensor>& as_ms_ks_lengths, - const std::array, NumATensor>& as_ms_ks_strides, - const std::array, NumBTensor>& bs_ns_ks_lengths, - const std::array, NumBTensor>& bs_ns_ks_strides, - const std::array, NumDTensor>& ds_ms_ns_lengths, - const std::array, NumDTensor>& ds_ms_ns_strides, - const std::vector& e_ms_ns_length, - const std::vector& e_ms_ns_stride, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) override - { - return std::make_unique(p_as, - p_bs, - p_ds, - p_e, - as_ms_ks_lengths, - as_ms_ks_strides, - bs_ns_ks_lengths, - bs_ns_ks_strides, - ds_ms_ns_lengths, - ds_ms_ns_strides, - e_ms_ns_length, - e_ms_ns_stride, - a_element_op, - b_element_op, - cde_element_op); - } - - // polymorphic - std::unique_ptr MakeInvokerPointer() override - { - return std::make_unique(Invoker{}); - } - - // polymorphic - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - std::map LoopSchedToString{ - {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; - - std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, - {PipelineVersion::v2, "v2"}}; - - // clang-format off - str << "DeviceContractionMultipleABD_Xdl_CShuffle" - << "<" - << BlockSize << ", " - << MPerBlock << ", " - << NPerBlock << ", " - << KPerBlock << ", " - << AK1 << ", " - << BK1 << ", " - << MPerXDL << ", " - << NPerXDL << ", " - << MXdlPerWave << ", " - << NXdlPerWave << ", " - << ABlockTransferSrcScalarPerVector << ", " - << BBlockTransferSrcScalarPerVector << ", " - << CShuffleMXdlPerWavePerShuffle << ", " - << CShuffleNXdlPerWavePerShuffle << ", " - << getGemmSpecializationString(GemmSpec) - << ">" - << " LoopScheduler: " - << LoopSchedToString[LoopSched] << ", " - << "PipelineVersion: " - << PipelineVersionToString[PipelineVer]; - // clang-format on - - return str.str(); - } -}; - -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp index db119178b7..5ae06836fa 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp @@ -112,7 +112,6 @@ template + CDEElementwiseOperation> { using DeviceOp = DeviceContractionMultipleD_Xdl_CShuffle; @@ -312,6 +310,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle using DsGridDesc_M_N = remove_cvref_t; using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {})); + using ComputeDataType = ADataType; + // GridwiseGemm using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< ADataType, // TODO: distinguish A/B datatype diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index 37d7b0b1fc..88b859efbc 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -198,7 +198,9 @@ template + LoopScheduler LoopSched = make_default_loop_scheduler(), + typename AComputeType = ADataType, + typename BComputeType = AComputeType> struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 : public DeviceGroupedConvBwdDataMultipleD + CDEElementwiseOp, + AComputeType, + BComputeType> { // TODO: Extend support for more spatial dimensions. static_assert(NDimSpatial == 2 || NDimSpatial == 3, @@ -312,9 +316,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 // GridwiseGemm using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< - ABDataType, // TODO: distinguish A/B datatype - ABDataType, // TODO: distinguish A/B datatype - ABDataType, // TODO: distinguish A/B datatype + ABDataType, + ABDataType, + AComputeType, AccDataType, CShuffleDataType, DsDataType, @@ -354,7 +358,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, - LoopSched>; + LoopSched, + PipelineVersion::v1, + BComputeType>; template static auto transform_k0_m_k1_to_m_k(const Desc_K0_M_K1& desc_k0_m_k1) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index 71c528e4c9..d4bba3bf64 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -48,7 +48,8 @@ struct ComputePtrOffsetOfStridedBatch } // namespace template (compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); - __shared__ FloatAB p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB)]; + __shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)]; GridwiseGemm::template Run(p_a_grid + a_batch_offset, p_b_grid + b_batch_offset, @@ -163,7 +164,9 @@ template + index_t CBlockTransferScalarPerVector_NWaveNPerXdl, + typename ComputeTypeA = InDataType, + typename ComputeTypeB = ComputeTypeA> struct DeviceGroupedConvBwdWeight_Xdl_CShuffle : public DeviceGroupedConvBwdWeight + OutElementwiseOperation, + ComputeTypeA, + ComputeTypeB> { using DeviceOp = DeviceGroupedConvBwdWeight_Xdl_CShuffle; @@ -1045,7 +1050,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< BlockSize, - ADataType, // TODO: distinguish A/B datatype + ADataType, + BDataType, AccDataType, CDataType, InMemoryDataOperationEnum::AtomicAdd, @@ -1090,7 +1096,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle CBlockTransferScalarPerVector_NWaveNPerXdl, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, true, - true>; + true, + 1, + PipelineVersion::v1, + ComputeTypeA, + ComputeTypeB>; // Argument using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = @@ -1217,8 +1227,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle index_t M01_; index_t N01_; - InElementwiseOperation a_element_op_; - OutElementwiseOperation b_element_op_; + OutElementwiseOperation a_element_op_; + InElementwiseOperation b_element_op_; WeiElementwiseOperation c_element_op_; // for checking IsSupportedArgument() @@ -1281,7 +1291,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle const auto kernel = kernel_batched_gemm_xdlops_bwd_weight< GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype + ADataType, + BDataType, CDataType, OutElementwiseOperation, InElementwiseOperation, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp index b04594fceb..f4b8d66ecf 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp @@ -211,7 +211,8 @@ template + typename ComputeDataType = ADataType, + LoopScheduler LoopSched = make_default_loop_scheduler()> struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle : public DeviceGroupedConvFwdMultipleD + CDEElementwiseOperation, + ComputeDataType> { using DeviceOp = DeviceGroupedConvFwdMultipleD_Xdl_CShuffle; @@ -323,8 +325,6 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle using DsGridDesc_M_N = remove_cvref_t; using EGridDesc_M_N = remove_cvref_t({}, {}))>; - using ComputeDataType = ADataType; - // GridwiseGemm using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< ADataType, // TODO: distinguish A/B datatype diff --git a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp index 1b9e37b74a..9fe0931cba 100644 --- a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp @@ -186,25 +186,6 @@ struct Bilinear y = type_convert(alpha_ * x0 + beta_ * ck::type_convert(x1)); }; - template <> - __host__ __device__ constexpr void - operator()(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const - { - const float x0_tmp = type_convert(x0); - const float x1_tmp = type_convert(x1); - const float y_tmp = alpha_ * x0_tmp + beta_ * x1_tmp; - y = type_convert(y_tmp); - }; - - template <> - __host__ __device__ constexpr void - operator()(bhalf_t& y, const float& x0, const bhalf_t& x1) const - { - const float x1_tmp = ck::type_convert(x1); - const float y_tmp = alpha_ * x0 + beta_ * x1_tmp; - y = y_tmp; - }; - template <> __host__ __device__ constexpr void operator()( std::int8_t& y, const std::int32_t& x0, const std::int8_t& x1) const diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 927574dfde..0b0e731c60 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -33,12 +33,6 @@ struct PassThrough y = type_convert(x); } - template <> - __host__ __device__ void operator()(double& y, const float& x) const - { - y = type_convert(x); - } - template <> __host__ __device__ void operator()(float& y, const float& x) const { @@ -75,12 +69,6 @@ struct PassThrough y = type_convert(x); } - template <> - __host__ __device__ void operator()(float& y, const bhalf_t& x) const - { - y = type_convert(x); - } - template <> __host__ __device__ void operator()(bhalf_t& y, const half_t& x) const { @@ -185,7 +173,8 @@ struct PassThrough template <> __host__ __device__ void operator()(bf8_t& y, const half_t& x) const { - y = type_convert(x); + // to-do: fix half_t to bf8_t convert + y = ck::type_convert(ck::type_convert(x)); } #endif }; @@ -242,20 +231,6 @@ struct Scale template __host__ __device__ void operator()(Y& y, const X& x) const; - template <> - __host__ __device__ void operator()(half_t& y, const half_t& x) const - { - y = ck::type_convert(scale_) * x; - }; - - template <> - __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const - { - const float x_tmp = ck::type_convert(x); - const float y_tmp = scale_ * x_tmp; - y = ck::type_convert(y_tmp); - }; - template <> __host__ __device__ void operator()(float& y, const float& x) const { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp index e7e012c009..10a0842dac 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp @@ -428,7 +428,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle [&](auto i) { using ALayout = remove_cvref_t>; - return MakeAGridDescriptor_M_K(MRaws[i], KRaws[i], AsStride[i]); + return MakeAGridDescriptor_M_N(MRaws[i], KRaws[i], AsStride[i]); }, Number{}); } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp index c7ec435477..15c30a0dad 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -31,7 +31,7 @@ namespace ck { // D0, D1, ... and E have the same layout template + PipelineVersion PipelineVer = PipelineVersion::v1, + typename BComputeDataType = AComputeDataType_> struct GridwiseGemmMultipleD_xdl_cshuffle { static constexpr index_t NumDTensor = DsDataType::Size(); @@ -100,10 +101,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle decltype(GridwiseGemmPipeline_Selector())>; #if CK_WORKAROUND_DENORM_FIX - using ComputeDataType = - conditional_t, ck::bhalf_t, ComputeDataType_>; + using AComputeDataType = + conditional_t, ck::bhalf_t, AComputeDataType_>; #else - using ComputeDataType = ComputeDataType_; + using AComputeDataType = AComputeDataType_; #endif __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() @@ -172,8 +173,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle constexpr auto c_block_size = c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * - sizeof(ComputeDataType), + return math::max(a_block_space_size_aligned * sizeof(AComputeDataType) + + b_block_space_size_aligned * sizeof(BComputeDataType), c_block_size * sizeof(CShuffleDataType)); } @@ -502,7 +503,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ADataType, - ComputeDataType, + AComputeDataType, decltype(a_grid_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1), ABlockTransferSrcAccessOrder, @@ -533,7 +534,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BDataType, - ComputeDataType, + BComputeDataType, decltype(b_grid_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1), BBlockTransferSrcAccessOrder, @@ -561,14 +562,15 @@ struct GridwiseGemmMultipleD_xdl_cshuffle // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // register // sanity check - constexpr index_t KPack = - math::max(math::lcm(AK1, BK1), - MfmaSelector::selected_mfma.k_per_blk); + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), + MfmaSelector::selected_mfma + .k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, - ComputeDataType, - ComputeDataType, + AComputeDataType, + BComputeDataType, AccDataType, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), @@ -586,10 +588,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); auto a_block_buf = make_dynamic_buffer( - static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); auto b_block_buf = make_dynamic_buffer( - static_cast(p_shared) + a_block_space_size_aligned, + static_cast(p_shared) + a_block_space_size_aligned, b_block_desc_bk0_n_bk1.GetElementSpaceSize()); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp index 8897ce5a10..06c87d1892 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -139,7 +139,8 @@ __host__ __device__ constexpr auto make_merge_transform_v4_no_carry(const LowLen } template + PipelineVersion PipelineVer = PipelineVersion::v1, + typename ComputeTypeA = FloatA, + typename ComputeTypeB = ComputeTypeA> struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight { static constexpr auto I0 = Number<0>{}; @@ -265,11 +269,16 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight // denorm test fix, required to work around fp16 mfma issue // we convert fp16->fp32->bf16 and execute bf16 mfma instruction // when mfma if fixed, remove this section and update - // FloatABAdjusted -> FloatAB throughout this file + // FloatAAdjusted -> ComputeTypeA, FloatBAdjusted -> ComputeTypeB, + // throughout this file #if CK_WORKAROUND_DENORM_FIX - using FloatABAdjusted = conditional_t, ck::bhalf_t, FloatAB>; + using FloatAAdjusted = + conditional_t, ck::bhalf_t, ComputeTypeA>; + using FloatBAdjusted = + conditional_t, ck::bhalf_t, ComputeTypeB>; #else - using FloatABAdjusted = FloatAB; + using FloatAAdjusted = ComputeTypeA; + using FloatBAdjusted = ComputeTypeB; #endif // M0/M1/M1Padding @@ -506,7 +515,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight constexpr auto c_block_size = GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize(); - return math::max((a_block_space_size + b_block_space_size) * sizeof(FloatAB), + return math::max((a_block_space_size * sizeof(FloatAAdjusted) + + b_block_space_size * sizeof(FloatBAdjusted)), c_block_size * sizeof(FloatC)); } @@ -610,8 +620,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1)); template - __device__ static void Run(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, + __device__ static void Run(const FloatA* __restrict__ p_a_grid, + const FloatB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, void* __restrict__ p_shared, const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc, @@ -673,8 +683,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight Sequence<1, K0PerBlock, MPerBlock, K1>, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatABAdjusted, + FloatA, + FloatAAdjusted, decltype(a_b_k0_m_k1_grid_desc), decltype(a_b_k0_m_k1_block_desc), ABlockTransferSrcAccessOrder, @@ -703,8 +713,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight Sequence<1, K0PerBlock, NPerBlock, K1>, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatABAdjusted, + FloatB, + FloatBAdjusted, decltype(b_b_k0_n_k1_grid_desc), decltype(b_b_k0_n_k1_block_desc), BBlockTransferSrcAccessOrder, @@ -733,12 +743,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight // sanity check constexpr index_t KPack = - math::max(K1, MfmaSelector::selected_mfma.k_per_blk); + math::max(K1, + MfmaSelector::selected_mfma + .k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1( - static_cast(p_shared), a_k0_m_k1_block_desc.GetElementSpaceSize()); + static_cast(p_shared), a_k0_m_k1_block_desc.GetElementSpaceSize()); auto b_block_buf = make_dynamic_buffer( - static_cast(p_shared) + a_block_space_size, + static_cast(p_shared) + a_block_space_size, b_k0_n_k1_block_desc.GetElementSpaceSize()); // gridwise GEMM pipeline diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index 06b95aef21..c8e56fbc56 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -32,8 +32,12 @@ enum struct MfmaInstr mfma_f64_16x16x4f64, mfma_f32_32x32x16f8f8, mfma_f32_16x16x32f8f8, + mfma_f32_32x32x16bf8bf8, + mfma_f32_16x16x32bf8bf8, mfma_f32_32x32x16f8bf8, - mfma_f32_16x16x32f8bf8 + mfma_f32_16x16x32f8bf8, + mfma_f32_32x32x16bf8f8, + mfma_f32_16x16x32bf8f8 }; template @@ -504,6 +508,52 @@ struct mfma_type }; #endif +#if defined CK_ENABLE_BF8 +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x16bf8bf8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x32bf8bf8::Run(a, b, reg_c); + } +}; +#endif + #if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 template <> struct mfma_type @@ -550,6 +600,52 @@ struct mfma_type }; #endif +#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x16bf8f8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x32bf8f8::Run(a, b, reg_c); + } +}; +#endif + template + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x16bf8bf8; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x32bf8bf8; + } +#endif + #if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 template <> static constexpr auto GetMfma() @@ -724,6 +834,20 @@ struct MfmaSelector } #endif +#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x16bf8f8; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x32bf8f8; + } +#endif + static constexpr auto selected_mfma = mfma_type()>{}; @@ -931,8 +1055,12 @@ struct XdlopsGemm #if defined CK_ENABLE_FP8 || is_same::value #endif +#if defined CK_ENABLE_BF8 + || is_same::value +#endif #if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 - || (is_same::value && is_same::value) + || (is_same::value && is_same::value) || + (is_same::value && is_same::value) #endif , "base base_type must be double, float, half, bfloat16, int8_t, f8_t or bf8_t!"); diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index 3768f92633..bf6241e46e 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -420,6 +420,71 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16> }; #endif +#if defined CK_ENABLE_BF8 +template +struct intrin_mfma_f32_32x32x16bf8bf8; + +template <> +struct intrin_mfma_f32_32x32x16bf8bf8<32, 32> +{ + template + __device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); +#else + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + static_for<0, 8, 1>{}([&](auto k) { + float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + + intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c); + }); +#endif + } +}; + +template +struct intrin_mfma_f32_16x16x32bf8bf8; + +template <> +struct intrin_mfma_f32_16x16x32bf8bf8<16, 16> +{ + template + __device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); +#else + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + static_for<0, 8, 1>{}([&](auto k) { + float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + + intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c); + }); +#endif + } +}; +#endif + #if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 template struct intrin_mfma_f32_32x32x16f8bf8; @@ -484,5 +549,70 @@ struct intrin_mfma_f32_16x16x32f8bf8<16, 16> } }; #endif + +#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 +template +struct intrin_mfma_f32_32x32x16bf8f8; + +template <> +struct intrin_mfma_f32_32x32x16bf8f8<32, 32> +{ + template + __device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); +#else + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + static_for<0, 8, 1>{}([&](auto k) { + float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + + intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c); + }); +#endif + } +}; + +template +struct intrin_mfma_f32_16x16x32bf8f8; + +template <> +struct intrin_mfma_f32_16x16x32bf8f8<16, 16> +{ + template + __device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); +#else + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + static_for<0, 8, 1>{}([&](auto k) { + float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + + intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c); + }); +#endif + } +}; +#endif } // namespace ck #endif diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 935926070d..505de73ae2 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -221,7 +221,7 @@ inline __host__ __device__ bf8_t type_convert(half_t x) { #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) // convert to float and use native converion - return type_convert(type_convert(x)); + return type_convert(type_convert(x)); #else constexpr bool negative_zero_nan = true; constexpr bool clip = true; diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_contraction.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_contraction.hpp index 527dac6d3b..92a8c82a6e 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_contraction.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_contraction.hpp @@ -23,7 +23,6 @@ template = false> @@ -70,24 +69,19 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base { for(ck::index_t k1 = 0; k1 < K1; ++k1) { - // Simulate the possible casting when ComputeDataType is different than the - // A/B data types - ComputeDataType v_a_compute_input = - ck::type_convert(arg.a_ms_ks_(m0, m1, k0, k1)); - ComputeDataType v_b_compute_input = - ck::type_convert(arg.b_ns_ks_(n0, n1, k0, k1)); - AccDataType v_a; AccDataType v_b; - arg.a_element_op_(v_a, ck::type_convert(v_a_compute_input)); - arg.b_element_op_(v_b, ck::type_convert(v_b_compute_input)); + arg.a_element_op_( + v_a, ck::type_convert(arg.a_ms_ks_(m0, m1, k0, k1))); + arg.b_element_op_( + v_b, ck::type_convert(arg.b_ns_ks_(n0, n1, k0, k1))); v_acc += v_a * v_b; } } - arg.c_ms_ns_(m0, m1, n0, n1) = ck::type_convert(v_acc); + arg.c_ms_ns_(m0, m1, n0, n1) = v_acc; }; make_ParallelTensorFunctor(f_ms_ns, diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp index ec5df238ab..02ad7a033a 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp @@ -25,6 +25,8 @@ template = 1 && NDimSpatial <= 3, bool>::type = false> struct ReferenceConvBwdWeight : public device::BaseOperator { @@ -98,8 +100,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator if(wi >= 0 && ck::type_convert(wi) < arg.input_.GetLengths()[3]) { - float v_out; - float v_in; + ComputeTypeA v_out; + ComputeTypeB v_in; arg.out_element_op_( v_out, ck::type_convert(arg.output_(g, n, k, wo))); @@ -107,7 +109,7 @@ struct ReferenceConvBwdWeight : public device::BaseOperator arg.in_element_op_( v_in, ck::type_convert(arg.input_(g, n, c, wi))); - v_acc += v_out * v_in; + v_acc += type_convert(v_out) * type_convert(v_in); } } } @@ -158,8 +160,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator wi >= 0 && ck::type_convert(wi) < arg.input_.GetLengths()[4]) { - float v_out; - float v_in; + ComputeTypeA v_out; + ComputeTypeB v_in; arg.out_element_op_( v_out, @@ -168,7 +170,7 @@ struct ReferenceConvBwdWeight : public device::BaseOperator arg.in_element_op_( v_in, ck::type_convert(arg.input_(g, n, c, hi, wi))); - v_acc += v_out * v_in; + v_acc += type_convert(v_out) * type_convert(v_in); } } } @@ -226,8 +228,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ck::type_convert(wi) < arg.input_.GetLengths()[5]) { - float v_out; - float v_in; + ComputeTypeA v_out; + ComputeTypeB v_in; arg.out_element_op_(v_out, ck::type_convert( @@ -237,7 +239,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ck::type_convert( arg.input_(g, n, c, di, hi, wi))); - v_acc += v_out * v_in; + v_acc += + type_convert(v_out) * type_convert(v_in); } } } diff --git a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp index f6fb3f9eee..ea11fd2e1a 100644 --- a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp +++ b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp @@ -29,8 +29,6 @@ using BF8 = ck::bf8_t; using Empty_Tuple = ck::Tuple<>; -using BF16_Tuple = ck::Tuple; - using F16_Tuple = ck::Tuple; using F16_F16_Tuple = ck::Tuple; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp deleted file mode 100644 index 96f4d4e30e..0000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp +++ /dev/null @@ -1,292 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. -#pragma once - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -using F16 = ck::half_t; -using BF16 = ck::bhalf_t; -using F32 = float; -using F64 = double; - -using F16_Tuple = ck::Tuple; -using BF16_Tuple = ck::Tuple; -using F32_Tuple = ck::Tuple; -using F64_Tuple = ck::Tuple; -using Empty_Tuple = ck::Tuple<>; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using Bilinear = ck::tensor_operation::element_wise::Bilinear; -using Scale = ck::tensor_operation::element_wise::Scale; - -template -using S = ck::Sequence; - -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -template -using device_contraction_kk_instance = std::tuple< - // clang-format off - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| Compute| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Data| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Type| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4> - // clang-format on - >; - -template -using device_contraction_kn_instance = std::tuple< - // clang-format off - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| Compute| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Data| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Type| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 1, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 1, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 1, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 1, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 1, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 1, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> - // clang-format on - >; - -template -using device_contraction_mk_instance = std::tuple< - // clang-format off - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| Compute| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Data| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Type| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 256, 16, 1, 4, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 4, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 4, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 4, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 4, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> - // clang-format on - >; - -template -using device_contraction_mn_instance = std::tuple< - // clang-format off - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| Compute| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Data| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Type| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 256, 16, 1, 1, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 1, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 1, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 1, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 1, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 1, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> - // clang-format on - >; - -template -using device_contraction_f64_kk_instance = std::tuple< - // clang-format off - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| Compute| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Data| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Type| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 64, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 32, 16, 2, 2, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 32, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 64, 32, 16, 2, 2, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 32, 64, 16, 2, 2, 16, 16, 2, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1> - // clang-format on - >; - -template -using device_contraction_f64_kn_instance = std::tuple< - // clang-format off - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| Compute| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Data| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Type| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 1, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 8>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 8, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 1, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 1, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1> - // clang-format on - >; - -template -using device_contraction_f64_mk_instance = std::tuple< - // clang-format off - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| Compute| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Data| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Type| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 2, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 2, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 2, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 2, 16, 16, 2, 4, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1> - // clang-format on - >; - -template -using device_contraction_f64_mn_instance = std::tuple< - // clang-format off - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| Compute| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Data| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Type| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 1, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 1, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 8>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 1, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 8, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 1, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 1, 16, 16, 2, 4, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1> - // clang-format on - >; - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp b/library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp index d06cab1196..bd3af891ec 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp @@ -17,6 +17,7 @@ namespace tensor_operation { namespace device { namespace instance { #ifdef CK_ENABLE_FP32 +// float void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance( std::vector>>& instances); + Bilinear>>>& instances); void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance( std::vector>>& instances); + Bilinear>>>& instances); void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance( std::vector>>& instances); + Bilinear>>>& instances); void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance( std::vector>>& instances); - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_kknn_instance( - std::vector>>& instances); - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_knnn_instance( - std::vector>>& instances); - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mknn_instance( - std::vector>>& instances); - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mnnn_instance( - std::vector>>& instances); - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_kknn_instance( - std::vector>>& instances); - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_knnn_instance( - std::vector>>& instances); - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mknn_instance( - std::vector>>& instances); - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mnnn_instance( - std::vector>>& instances); -#endif // CK_ENABLE_FP32 - + Bilinear>>>& instances); +#endif #ifdef CK_ENABLE_FP64 +// double void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance( std::vector>>& instances); + Bilinear>>>& instances); void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance( std::vector>>& instances); + Bilinear>>>& instances); void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance( std::vector>>& instances); + Bilinear>>>& instances); void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance( std::vector>>& instances); - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_kknn_instance( - std::vector>>& instances); - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_knnn_instance( - std::vector>>& instances); - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mknn_instance( - std::vector>>& instances); - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mnnn_instance( - std::vector>>& instances); -#endif // CK_ENABLE_FP64 - -#ifdef CK_ENABLE_FP16 -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance( - std::vector>>& instances); - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_knnn_instance( - std::vector>>& instances); - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mknn_instance( - std::vector>>& instances); - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mnnn_instance( - std::vector>>& instances); -#endif // CK_ENABLE_FP16 - -#ifdef CK_ENABLE_BF16 -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance( - std::vector>>& instances); - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_knnn_instance( - std::vector>>& instances); - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mknn_instance( - std::vector>>& instances); - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mnnn_instance( - std::vector>>& instances); -#endif // CK_ENABLE_FP16 - + Bilinear>>>& instances); +#endif // Contraction + Bilinear template + typename EDataType> struct DeviceOperationInstanceFactory> + ck::tensor_operation::element_wise::Bilinear>> { using DeviceOp = DeviceContractionMultipleD; + ck::tensor_operation::element_wise::Bilinear>; static auto GetInstances() { std::vector> op_ptrs; #ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && - is_same_v) + is_same_v && is_same_v) { if constexpr(NumDimM == 2 && NumDimN == 2 && NumDimK == 2) { - if constexpr(is_same_v) - { - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance( - op_ptrs); - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance( - op_ptrs); - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance( - op_ptrs); - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance( - op_ptrs); - } - else if constexpr(is_same_v) - { - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_kknn_instance( - op_ptrs); - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_knnn_instance( - op_ptrs); - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mknn_instance( - op_ptrs); - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mnnn_instance( - op_ptrs); - } - else if constexpr(is_same_v) - { - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_kknn_instance( - op_ptrs); - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_knnn_instance( - op_ptrs); - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mknn_instance( - op_ptrs); - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mnnn_instance( - op_ptrs); - } + add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance( + op_ptrs); + add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance( + op_ptrs); + add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance( + op_ptrs); + add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance( + op_ptrs); } } -#endif // CK_ENABLE_FP32 +#endif #ifdef CK_ENABLE_FP64 if constexpr(is_same_v && is_same_v && - is_same_v) + is_same_v && is_same_v) { if constexpr(NumDimM == 2 && NumDimN == 2 && NumDimK == 2) { - if constexpr(is_same_v) - { - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance( - op_ptrs); - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance( - op_ptrs); - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance( - op_ptrs); - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance( - op_ptrs); - } - else if constexpr(is_same_v) - { - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_kknn_instance( - op_ptrs); - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_knnn_instance( - op_ptrs); - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mknn_instance( - op_ptrs); - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mnnn_instance( - op_ptrs); - } + add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance( + op_ptrs); + add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance( + op_ptrs); + add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance( + op_ptrs); + add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance( + op_ptrs); } } -#endif // CK_ENABLE_FP64 -#ifdef CK_ENABLE_FP16 - if constexpr(is_same_v && is_same_v && - is_same_v) - { - if constexpr(NumDimM == 2 && NumDimN == 2 && NumDimK == 2) - { - if constexpr(is_same_v) - { - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance( - op_ptrs); - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_knnn_instance( - op_ptrs); - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mknn_instance( - op_ptrs); - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mnnn_instance( - op_ptrs); - } - } - } -#endif // CK_ENABLE_FP16 -#ifdef CK_ENABLE_BF16 - if constexpr(is_same_v && is_same_v && - is_same_v) - { - if constexpr(NumDimM == 2 && NumDimN == 2 && NumDimK == 2) - { - if constexpr(is_same_v) - { - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance( - op_ptrs); - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_knnn_instance( - op_ptrs); - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mknn_instance( - op_ptrs); - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mnnn_instance( - op_ptrs); - } - } - } -#endif // CK_ENABLE_BF16 +#endif return op_ptrs; } }; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/contraction_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/contraction_scale.hpp index 8e994d61ae..4dd735d062 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/contraction_scale.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/contraction_scale.hpp @@ -17,6 +17,7 @@ namespace tensor_operation { namespace device { namespace instance { #ifdef CK_ENABLE_FP32 +// float void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance( std::vector>>& instances); + Scale>>>& instances); void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance( std::vector>>& instances); + Scale>>>& instances); void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance( std::vector>>& instances); + Scale>>>& instances); void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance( std::vector>>& instances); - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_kkn_instance( - std::vector>>& instances); - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_knn_instance( - std::vector>>& instances); - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mkn_instance( - std::vector>>& instances); - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mnn_instance( - std::vector>>& instances); - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_kkn_instance( - std::vector>>& instances); - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_knn_instance( - std::vector>>& instances); - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mkn_instance( - std::vector>>& instances); - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mnn_instance( - std::vector>>& instances); -#endif // CK_ENABLE_FP32 - + Scale>>>& instances); +#endif #ifdef CK_ENABLE_FP64 +// double void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance( std::vector>>& instances); + Scale>>>& instances); void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance( std::vector>>& instances); + Scale>>>& instances); void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance( std::vector>>& instances); + Scale>>>& instances); void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance( std::vector>>& instances); - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_kkn_instance( - std::vector>>& instances); - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_knn_instance( - std::vector>>& instances); - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mkn_instance( - std::vector>>& instances); - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mnn_instance( - std::vector>>& instances); -#endif // CK_ENABLE_FP64 - -#ifdef CK_ENABLE_FP16 -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_kkn_instance( - std::vector>>& instances); - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_knn_instance( - std::vector>>& instances); - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mkn_instance( - std::vector>>& instances); - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mnn_instance( - std::vector>>& instances); -#endif // CK_ENABLE_FP16 - -#ifdef CK_ENABLE_BF16 -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance( - std::vector>>& instances); - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_knn_instance( - std::vector>>& instances); - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mkn_instance( - std::vector>>& instances); - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mnn_instance( - std::vector>>& instances); -#endif // CK_ENABLE_FP16 - + Scale>>>& instances); +#endif // Contraction + Scale template + typename EDataType> struct DeviceOperationInstanceFactory> + ck::tensor_operation::element_wise::Scale>> { using DeviceOp = DeviceContractionMultipleD; + ck::tensor_operation::element_wise::Scale>; static auto GetInstances() { @@ -430,113 +155,34 @@ struct DeviceOperationInstanceFactory) - { - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance( - op_ptrs); - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance( - op_ptrs); - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance( - op_ptrs); - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance( - op_ptrs); - } - else if constexpr(is_same_v) - { - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_kkn_instance( - op_ptrs); - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_knn_instance( - op_ptrs); - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mkn_instance( - op_ptrs); - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mnn_instance( - op_ptrs); - } - else if constexpr(is_same_v) - { - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_kkn_instance( - op_ptrs); - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_knn_instance( - op_ptrs); - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mkn_instance( - op_ptrs); - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mnn_instance( - op_ptrs); - } + add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance( + op_ptrs); + add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance( + op_ptrs); + add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance( + op_ptrs); + add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance( + op_ptrs); } } -#endif // CK_ENABLE_FP32 +#endif #ifdef CK_ENABLE_FP64 if constexpr(is_same_v && is_same_v && is_same_v) { if constexpr(NumDimM == 2 && NumDimN == 2 && NumDimK == 2) { - if constexpr(is_same_v) - { - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance( - op_ptrs); - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance( - op_ptrs); - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance( - op_ptrs); - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance( - op_ptrs); - } - else if constexpr(is_same_v) - { - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_kkn_instance( - op_ptrs); - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_knn_instance( - op_ptrs); - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mkn_instance( - op_ptrs); - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mnn_instance( - op_ptrs); - } + add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance( + op_ptrs); + add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance( + op_ptrs); + add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance( + op_ptrs); + add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance( + op_ptrs); } } -#endif // CK_ENABLE_FP64 -#ifdef CK_ENABLE_FP16 - if constexpr(is_same_v && is_same_v && - is_same_v) - { - if constexpr(NumDimM == 2 && NumDimN == 2 && NumDimK == 2) - { - if constexpr(is_same_v) - { - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_kkn_instance( - op_ptrs); - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_knn_instance( - op_ptrs); - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mkn_instance( - op_ptrs); - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mnn_instance( - op_ptrs); - } - } - } -#endif // CK_ENABLE_FP16 -#ifdef CK_ENABLE_BF16 - if constexpr(is_same_v && is_same_v && - is_same_v) - { - if constexpr(NumDimM == 2 && NumDimN == 2 && NumDimK == 2) - { - if constexpr(is_same_v) - { - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance( - op_ptrs); - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_knn_instance( - op_ptrs); - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mkn_instance( - op_ptrs); - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mnn_instance( - op_ptrs); - } - } - } -#endif // CK_ENABLE_BF16 +#endif return op_ptrs; } }; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp index 26acf4f5f7..ee86950992 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp @@ -18,6 +18,8 @@ namespace instance { using BF16 = ck::bhalf_t; using F16 = ck::half_t; using F32 = float; +using BF8 = ck::bf8_t; +using F8 = ck::f8_t; using Empty_Tuple = ck::Tuple<>; @@ -143,6 +145,43 @@ using device_grouped_conv_bwd_data_xdl_f32_instances = // clang-format on >; +// f16_f16_f16_comp_f8 +template +using device_grouped_conv_bwd_data_xdl_input_fp16_comp_bf8f8_instances = + std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, BF8, F8>, + // instances for small conv.K and conv.C + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, LoopScheduler::Default, BF8, F8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 32, 1, 4>, 1, LoopScheduler::Default, BF8, F8>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, LoopScheduler::Default, BF8, F8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, LoopScheduler::Default, BF8, F8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, LoopScheduler::Default, BF8, F8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, LoopScheduler::Default, BF8, F8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, LoopScheduler::Default, BF8, F8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4, LoopScheduler::Default, BF8, F8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, LoopScheduler::Default, BF8, F8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4, LoopScheduler::Default, BF8, F8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, LoopScheduler::Default, BF8, F8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, LoopScheduler::Default, BF8, F8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4, LoopScheduler::Default, BF8, F8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, LoopScheduler::Default, BF8, F8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4, LoopScheduler::Default, BF8, F8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, BF8, F8> + // clang-format on + >; + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp index 6fc91565b9..096e0b1770 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp @@ -19,6 +19,14 @@ using BF16 = ck::bhalf_t; using F16 = ck::half_t; using F32 = float; +#ifdef CK_ENABLE_FP8 +using F8 = ck::f8_t; +#endif + +#ifdef CK_ENABLE_BF8 +using BF8 = ck::bf8_t; +#endif + using Empty_Tuple = ck::Tuple<>; template @@ -133,6 +141,43 @@ using device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances = std::tuple< // clang-format on >; +template +using device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_comp_bf8_f8_instances = std::tuple< +// clang-format off + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| Compute| Compute| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| TypeA| TypeB| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| | | + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | +#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 + // generic instance + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2, BF8, F8>, + // instance for small conv.K + // for fp16 conv.K and conv.C must be divisible by 2 + // since half_t atomic_add require scalar_per_x_vector % 2 == 0 + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2, BF8, F8>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>, + + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8> +#endif + // clang-format on + >; + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp index 23edf35e98..17bb4256ac 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp @@ -13,6 +13,10 @@ namespace tensor_operation { namespace device { namespace instance { +#ifdef CK_ENABLE_FP8 +using F8 = ck::f8_t; +#endif + using BF16 = ck::bhalf_t; using F16 = ck::half_t; using F32 = float; @@ -174,6 +178,42 @@ using device_grouped_conv_fwd_xdl_int8_instances = std::tuple< // clang-format on >; +template +using device_grouped_conv_fwd_xdl_f16_comp_f8_instances = std::tuple< +// clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| ComputeType| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#ifdef CK_ENABLE_FP8 + // generic instance + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, + + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8> +#endif + // clang-format on + >; + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp index b259c640da..5cdaa48bec 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp @@ -426,13 +426,32 @@ void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_1x1s1p0_ins PassThrough, PassThrough>>>& instances); #endif +#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 +void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances( + std::vector>>& instances); +#endif template + typename InDataType, + typename ComputeTypeA, + typename ComputeTypeB> struct DeviceOperationInstanceFactory< ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD< NumDimSpatial, @@ -446,7 +465,9 @@ struct DeviceOperationInstanceFactory< InDataType, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough>> + ck::tensor_operation::element_wise::PassThrough, + ComputeTypeA, + ComputeTypeB>> { using DeviceOp = DeviceGroupedConvBwdDataMultipleD; + ck::tensor_operation::element_wise::PassThrough, + ComputeTypeA, + ComputeTypeB>; static auto GetInstances() { @@ -597,7 +620,8 @@ struct DeviceOperationInstanceFactory< { #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && - is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances( op_ptrs); @@ -607,6 +631,15 @@ struct DeviceOperationInstanceFactory< op_ptrs); } #endif +#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances( + op_ptrs); + } +#endif #ifdef CK_ENABLE_FP32 else if constexpr(is_same_v && is_same_v && is_same_v) diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp index c0a5b4bfe5..9f536f210d 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp @@ -216,6 +216,21 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances PassThrough, PassThrough>>>& instances); #endif +#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances( + std::vector>>& instances); +#endif #ifdef DL_KERNELS // dl @@ -464,7 +479,9 @@ template + typename OutDataType, + typename ComputeTypeA, + typename ComputeTypeB> struct DeviceOperationInstanceFactory> + ck::tensor_operation::element_wise::PassThrough, + ComputeTypeA, + ComputeTypeB>> { using DeviceOp = DeviceGroupedConvBwdWeight; + ck::tensor_operation::element_wise::PassThrough, + ComputeTypeA, + ComputeTypeB>; static auto GetInstances() { @@ -706,7 +727,9 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + is_same_v && + is_same_v && + is_same_v) { #ifdef DL_KERNELS add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instances( @@ -728,6 +751,15 @@ struct DeviceOperationInstanceFactory && is_same_v && + is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances( + op_ptrs); + } #endif } } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index e6dbd349dd..888c00f900 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -16,6 +16,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { + #ifdef CK_ENABLE_BF16 // grouped conv1d forward, GNWC/GKXC/GNWK void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances( @@ -32,6 +33,7 @@ void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances( PassThrough, PassThrough>>>& instances); #endif + #ifdef CK_ENABLE_FP16 void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances( std::vector>>& instances); #endif + #ifdef CK_ENABLE_FP32 void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances( std::vector>>& instances); #endif + #ifdef CK_ENABLE_INT8 void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances( std::vector>>& instances); #endif + +#ifdef CK_ENABLE_INT8 +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_INT8 +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_INT8 +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_INT8 +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instances( + std::vector>>& instances); +#endif + #ifdef CK_ENABLE_BF16 // grouped conv2d forward, GNHWC/GKYXC/GNHWK void add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances( @@ -93,6 +162,7 @@ void add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances( PassThrough, PassThrough>>>& instances); #endif + #ifdef CK_ENABLE_FP16 void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances( std::vector>>& instances); #endif + #ifdef CK_ENABLE_FP32 void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances( std::vector>>& instances); #endif -#ifdef DL_KERNELS -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances( - std::vector>>& instances); -#endif -#endif + #ifdef CK_ENABLE_FP16 void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances( std::vector>>& instances); - -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1p0_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instances( - std::vector>>& instances); - -#ifdef DL_KERNELS -void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances( - std::vector>>& instances); -#endif #endif + #ifdef CK_ENABLE_INT8 void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances( std::vector>>& instances); #endif -#if(defined(CK_ENABLE_FP32) && defined(DL_KERNELS)) -void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances( - std::vector>>& instances); -#endif // grouped conv2d forward, NHWGC/GKYXC/NHWGK #ifdef CK_ENABLE_BF16 void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( @@ -317,6 +285,7 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( PassThrough, PassThrough>>>& instances); #endif + #ifdef CK_ENABLE_FP16 void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_instances( std::vector>>& instances); #endif -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instances( - std::vector>>& instances); -void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instances( - std::vector>>& instances); -#endif #ifdef CK_ENABLE_FP32 void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances( std::vector>>& instances); #endif + #ifdef CK_ENABLE_BF16 // grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances( @@ -476,6 +390,7 @@ void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances( PassThrough, PassThrough>>>& instances); #endif + #ifdef CK_ENABLE_FP16 void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances( std::vector>>& instances); #endif + #ifdef CK_ENABLE_FP32 void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( std::vector>>& instances); #endif + #ifdef CK_ENABLE_INT8 void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances( std::vector>>& instances); #endif + #ifdef CK_ENABLE_BF16 // grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances( @@ -649,6 +567,7 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances( PassThrough, PassThrough>>>& instances); #endif + #ifdef CK_ENABLE_FP16 void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances( std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1p0_instances( std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances( std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instances( std::vector>>& instances); #endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1p0_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP8 +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_f8_instances( + std::vector>>& instances); +#endif + #ifdef CK_ENABLE_FP32 void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( std::vector>>& instances); #endif + #ifdef CK_ENABLE_INT8 void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances( std::vector>>& instances); #endif +#if(defined(CK_ENABLE_FP32) && defined(DL_KERNELS)) +void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances( + std::vector>>& instances); + +#endif + +#if(defined(CK_ENABLE_FP16) && defined(DL_KERNELS)) +void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances); +#endif + +#if(defined(CK_ENABLE_FP16) && defined(DL_KERNELS)) +void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances( + std::vector>>& instances); +#endif + +#if(defined(CK_ENABLE_FP32) && defined(DL_KERNELS)) +void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances( + std::vector>>& instances); +#endif + template + typename OutDataType, + typename ComputeType> struct DeviceOperationInstanceFactory> + ck::tensor_operation::element_wise::PassThrough, + ComputeType>> { using DeviceOp = DeviceGroupedConvFwdMultipleD; + ck::tensor_operation::element_wise::PassThrough, + ComputeType>; static auto GetInstances() { @@ -877,33 +955,46 @@ struct DeviceOperationInstanceFactory && - is_same_v && is_same_v) + + if constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v) { #ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs); -#ifdef DL_KERNELS - add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs); -#endif } #endif + +#if(defined(CK_ENABLE_FP32) && defined(DL_KERNELS)) + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs); + } +#endif + #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v) { add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); -#ifdef DL_KERNELS - add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); -#endif add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1p0_instances(op_ptrs); add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instances(op_ptrs); add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instances(op_ptrs); } #endif + +#if(defined(CK_ENABLE_FP16) && defined(DL_KERNELS)) + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); + } +#endif + #ifdef CK_ENABLE_BF16 if constexpr(is_same_v && is_same_v && is_same_v) @@ -911,9 +1002,10 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(op_ptrs); add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instances(op_ptrs); @@ -922,33 +1014,43 @@ struct DeviceOperationInstanceFactory && - is_same_v && is_same_v) + + if constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v) { + #ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); -#ifdef DL_KERNELS - add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); -#endif } #endif + +#if(defined(CK_ENABLE_FP32) && defined(DL_KERNELS)) + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); + } +#endif + #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs); -#ifdef DL_KERNELS - add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs); -#endif - add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs); - add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1p0_instances(op_ptrs); - add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instances(op_ptrs); - add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instances(op_ptrs); } #endif + +#if(defined(CK_ENABLE_FP16) && defined(DL_KERNELS)) + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs); + } +#endif + #ifdef CK_ENABLE_BF16 if constexpr(is_same_v && is_same_v && is_same_v) @@ -967,8 +1069,9 @@ struct DeviceOperationInstanceFactory && - is_same_v && is_same_v) + + if constexpr(NumDimSpatial == 3 && is_same_v && + is_same_v && is_same_v) { #ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && @@ -1010,8 +1113,9 @@ struct DeviceOperationInstanceFactory && - is_same_v && is_same_v) + + if constexpr(NumDimSpatial == 3 && is_same_v && + is_same_v && is_same_v) { #ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && @@ -1020,9 +1124,18 @@ struct DeviceOperationInstanceFactory && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_f8_instances( + op_ptrs); + } +#endif #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && - is_same_v) + is_same_v && is_same_v) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs); add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs); diff --git a/library/include/ck/library/utility/host_tensor_generator.hpp b/library/include/ck/library/utility/host_tensor_generator.hpp index 07d3a35a18..6b531c4c6f 100644 --- a/library/include/ck/library/utility/host_tensor_generator.hpp +++ b/library/include/ck/library/utility/host_tensor_generator.hpp @@ -111,6 +111,22 @@ struct GeneratorTensor_2 }; #endif +#if defined CK_ENABLE_BF8 +template <> +struct GeneratorTensor_2 +{ + int min_value = 0; + int max_value = 1; + + template + ck::bf8_t operator()(Is...) + { + float tmp = (std::rand() % (max_value - min_value)) + min_value; + return ck::type_convert(tmp); + } +}; +#endif + template struct GeneratorTensor_3 { @@ -162,6 +178,25 @@ struct GeneratorTensor_3 }; #endif +#if defined CK_ENABLE_BF8 +template <> +struct GeneratorTensor_3 +{ + float min_value = 0; + float max_value = 1; + + template + ck::bf8_t operator()(Is...) + { + float tmp = float(std::rand()) / float(RAND_MAX); + + float fp32_tmp = min_value + tmp * (max_value - min_value); + + return ck::type_convert(fp32_tmp); + } +}; +#endif + template struct GeneratorTensor_4 { diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index a0478c9f09..9c1f72eca3 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -24,7 +24,7 @@ function(add_instance_library INSTANCE_NAME) set(test 0) break() elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR - source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND + source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND NOT(source MATCHES type OR source MATCHES type1)) #if filename contains a type which doesn't match any selected type, mark it for removal set(test 1) @@ -51,7 +51,7 @@ function(add_instance_library INSTANCE_NAME) set(result 0) endif() #message("add_instance_library returns ${result}") - return(PROPAGATE result) + set(result ${result} PARENT_SCOPE) endfunction(add_instance_library INSTANCE_NAME) diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt index 87a6bbba4e..ee9c419e1a 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt @@ -1,42 +1,14 @@ set(DEVICE_CONTRACTION_BILINEAR_INSTANCES) - -# FP32 +#float list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp) + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp) -list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_kknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_knnn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mnnn_instance.cpp) - -list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_kknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_knnn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mnnn_instance.cpp) - -# FP64 +#double list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp) - -list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_kknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_knnn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mnnn_instance.cpp) - -# FP16 -list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_knnn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mnnn_instance.cpp) - -# BF16 -list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_knnn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mnnn_instance.cpp) + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp) add_instance_library(device_contraction_bilinear_instance ${DEVICE_CONTRACTION_BILINEAR_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance.cpp deleted file mode 100644 index 56dc1d2c71..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// k/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance = - device_contraction_kk_instance; - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_knnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_knnn_instance.cpp deleted file mode 100644 index 7b8163814e..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_knnn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// k/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_knnn_instance = - device_contraction_kn_instance; - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_knnn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_knnn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mknn_instance.cpp deleted file mode 100644 index ba202202f9..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mknn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// m/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mknn_instance = - device_contraction_mk_instance; - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mknn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mknn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mnnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mnnn_instance.cpp deleted file mode 100644 index b6acab1c18..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mnnn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// m/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mnnn_instance = - device_contraction_mn_instance; - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mnnn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mnnn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance.cpp deleted file mode 100644 index 1bd3495b36..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// k/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance = - device_contraction_kk_instance; - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_knnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_knnn_instance.cpp deleted file mode 100644 index 6817546d3d..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_knnn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// k/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_knnn_instance = - device_contraction_kn_instance; - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_knnn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_knnn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mknn_instance.cpp deleted file mode 100644 index cedace2cb2..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mknn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// m/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mknn_instance = - device_contraction_mk_instance; - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mknn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mknn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mnnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mnnn_instance.cpp deleted file mode 100644 index b8383e5268..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mnnn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// m/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mnnn_instance = - device_contraction_mn_instance; - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mnnn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mnnn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_kknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_kknn_instance.cpp deleted file mode 100644 index 2e73bdb804..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_kknn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// k/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_kknn_instance = - device_contraction_kk_instance; - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_kknn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_kknn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_knnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_knnn_instance.cpp deleted file mode 100644 index ae69db03ac..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_knnn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// k/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_knnn_instance = - device_contraction_kn_instance; - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_knnn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_knnn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mknn_instance.cpp deleted file mode 100644 index fc6b6003f7..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mknn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// m/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mknn_instance = - device_contraction_mk_instance; - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mknn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mknn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mnnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mnnn_instance.cpp deleted file mode 100644 index b81984acc4..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mnnn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// m/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mnnn_instance = - device_contraction_mn_instance; - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mnnn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mnnn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_kknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_kknn_instance.cpp deleted file mode 100644 index 8da6a79f0a..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_kknn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// k/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_kknn_instance = - device_contraction_kk_instance; - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_kknn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_kknn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_knnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_knnn_instance.cpp deleted file mode 100644 index b05c443e85..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_knnn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// k/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_knnn_instance = - device_contraction_kn_instance; - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_knnn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_knnn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mknn_instance.cpp deleted file mode 100644 index c2cbb9f25f..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mknn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// m/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mknn_instance = - device_contraction_mk_instance; - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mknn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mknn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mnnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mnnn_instance.cpp deleted file mode 100644 index 887eaad7eb..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mnnn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// m/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mnnn_instance = - device_contraction_mn_instance; - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mnnn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mnnn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp index 9a3c71acb5..5587db77e0 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp @@ -9,9 +9,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" + #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { @@ -19,19 +21,40 @@ namespace tensor_operation { namespace device { namespace instance { +using F32 = float; +using F32_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + // A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] // k/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance = - device_contraction_kk_instance; +using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4> + // clang-format on + >; void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance( std::vector>>& instances) + Bilinear>>>& instances) { add_device_operation_instances( instances, diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp index 5659d477d4..26262855ea 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp @@ -9,9 +9,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" + #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { @@ -19,19 +21,43 @@ namespace tensor_operation { namespace device { namespace instance { +using F32 = float; +using F32_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + // A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] // k/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance = - device_contraction_kn_instance; +using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 1, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 1, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 1, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 1, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 1, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 1, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + // clang-format on + >; void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance( std::vector>>& instances) + Bilinear>>>& instances) { add_device_operation_instances( instances, diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp index 7cbe1e34be..befc0dcd10 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp @@ -9,9 +9,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" + #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { @@ -19,19 +21,43 @@ namespace tensor_operation { namespace device { namespace instance { +using F32 = float; +using F32_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + // A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] // m/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance = - device_contraction_mk_instance; +using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 16, 1, 4, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 4, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 4, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 4, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 4, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + // clang-format on + >; void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance( std::vector>>& instances) + Bilinear>>>& instances) { add_device_operation_instances( instances, diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp index e26dcf394b..e45b47cf94 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp @@ -9,9 +9,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" + #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { @@ -19,19 +21,43 @@ namespace tensor_operation { namespace device { namespace instance { +using F32 = float; +using F32_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + // A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] // m/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance = - device_contraction_mn_instance; +using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 16, 1, 1, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 1, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 1, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 1, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 1, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 1, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + // clang-format on + >; void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance( std::vector>>& instances) + Bilinear>>>& instances) { add_device_operation_instances( instances, diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_kknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_kknn_instance.cpp deleted file mode 100644 index 74c5c9c653..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_kknn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// k/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_kknn_instance = - device_contraction_f64_kk_instance; - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_kknn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_kknn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_knnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_knnn_instance.cpp deleted file mode 100644 index 8a2705f722..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_knnn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// k/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_knnn_instance = - device_contraction_f64_kn_instance; - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_knnn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_knnn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mknn_instance.cpp deleted file mode 100644 index 47924a843d..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mknn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// m/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mknn_instance = - device_contraction_f64_mk_instance; - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mknn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mknn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mnnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mnnn_instance.cpp deleted file mode 100644 index d6c684e2b0..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mnnn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// m/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mnnn_instance = - device_contraction_f64_mn_instance; - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mnnn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mnnn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp index 75bd8e3f4a..f437a227d5 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp @@ -9,9 +9,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" + #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { @@ -19,19 +21,37 @@ namespace tensor_operation { namespace device { namespace instance { +using F64 = double; +using F64_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + // A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] // k/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance = - device_contraction_f64_kk_instance; +using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 64, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 32, 16, 2, 2, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 32, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 64, 32, 16, 2, 2, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 64, 16, 2, 2, 16, 16, 2, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1> + // clang-format on + >; void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance( std::vector>>& instances) + Bilinear>>>& instances) { add_device_operation_instances( instances, diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp index c515f37916..13fdbeb35c 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp @@ -9,9 +9,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" + #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { @@ -19,19 +21,37 @@ namespace tensor_operation { namespace device { namespace instance { +using F64 = double; +using F64_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + // A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] // k/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance = - device_contraction_f64_kn_instance; +using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 1, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 1, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 1, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1> + // clang-format on + >; void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance( std::vector>>& instances) + Bilinear>>>& instances) { add_device_operation_instances( instances, diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp index ff9fa61e0b..95ef8c4929 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp @@ -9,9 +9,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" + #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { @@ -19,19 +21,37 @@ namespace tensor_operation { namespace device { namespace instance { +using F64 = double; +using F64_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + // A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] // m/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance = - device_contraction_f64_mk_instance; +using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 2, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 2, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 2, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 2, 16, 16, 2, 4, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1> + // clang-format on + >; void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance( std::vector>>& instances) + Bilinear>>>& instances) { add_device_operation_instances( instances, diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp index f274c45b58..290f81d7c9 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp @@ -9,9 +9,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" + #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { @@ -19,19 +21,37 @@ namespace tensor_operation { namespace device { namespace instance { +using F64 = double; +using F64_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + // A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] // m/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance = - device_contraction_f64_mn_instance; +using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 1, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 1, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 1, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 1, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 1, 16, 16, 2, 4, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1> + // clang-format on + >; void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance( std::vector>>& instances) + Bilinear>>>& instances) { add_device_operation_instances( instances, diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt index a0918d9d3f..673fae3373 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt @@ -1,43 +1,15 @@ set(DEVICE_CONTRACTION_SCALE_INSTANCES) - -# FP32 +#float list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp) -list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_kkn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_knn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mkn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mnn_instance.cpp) - -list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_kkn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_knn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mkn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mnn_instance.cpp) - -# FP64 +#double list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp) -list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_kkn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_knn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mkn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mnn_instance.cpp) - -# FP16 -list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_kkn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_knn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mkn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mnn_instance.cpp) - -# BF16 -list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_knn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mkn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mnn_instance.cpp) - add_instance_library(device_contraction_scale_instance ${DEVICE_CONTRACTION_SCALE_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance.cpp deleted file mode 100644 index ace4d4a330..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// k/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance = - device_contraction_kk_instance; - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_knn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_knn_instance.cpp deleted file mode 100644 index d2a1d45f7a..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_knn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// k/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_knn_instance = - device_contraction_kn_instance; - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_knn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_knn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mkn_instance.cpp deleted file mode 100644 index aa4e4df9d9..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mkn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// m/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mkn_instance = - device_contraction_mk_instance; - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mkn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mkn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mnn_instance.cpp deleted file mode 100644 index a59bce0a1f..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mnn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// m/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mnn_instance = - device_contraction_mn_instance; - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mnn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mnn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_kkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_kkn_instance.cpp deleted file mode 100644 index 876acd3f69..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_kkn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// k/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_kkn_instance = - device_contraction_kk_instance; - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_kkn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_kkn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_knn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_knn_instance.cpp deleted file mode 100644 index 96e961d3bb..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_knn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// k/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_knn_instance = - device_contraction_kn_instance; - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_knn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_knn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mkn_instance.cpp deleted file mode 100644 index 69cae90889..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mkn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// m/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mkn_instance = - device_contraction_mk_instance; - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mkn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mkn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mnn_instance.cpp deleted file mode 100644 index e6ac98273a..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mnn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// m/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mnn_instance = - device_contraction_mn_instance; - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mnn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mnn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_kkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_kkn_instance.cpp deleted file mode 100644 index da1df4c43e..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_kkn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// k/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_kkn_instance = - device_contraction_kk_instance; - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_kkn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_kkn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_knn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_knn_instance.cpp deleted file mode 100644 index 21c4e87fb1..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_knn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// k/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_knn_instance = - device_contraction_kn_instance; - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_knn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_knn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mkn_instance.cpp deleted file mode 100644 index 9d22f8feff..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mkn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// m/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mkn_instance = - device_contraction_mk_instance; - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mkn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mkn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mnn_instance.cpp deleted file mode 100644 index 8e861f74c9..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mnn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// m/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mnn_instance = - device_contraction_mn_instance; - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mnn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mnn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_kkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_kkn_instance.cpp deleted file mode 100644 index 390b5526e3..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_kkn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// k/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_kkn_instance = - device_contraction_kk_instance; - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_kkn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_kkn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_knn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_knn_instance.cpp deleted file mode 100644 index 385a13cc25..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_knn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// k/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_knn_instance = - device_contraction_kn_instance; - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_knn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_knn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mkn_instance.cpp deleted file mode 100644 index 1afd432bac..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mkn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// m/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mkn_instance = - device_contraction_mk_instance; - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mkn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mkn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mnn_instance.cpp deleted file mode 100644 index bfb4d4630f..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mnn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// m/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mnn_instance = - device_contraction_mn_instance; - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mnn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mnn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp index 51bc64eaad..16fd1cb407 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp @@ -9,9 +9,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" + #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { @@ -19,19 +21,40 @@ namespace tensor_operation { namespace device { namespace instance { -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// k/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance = - device_contraction_kk_instance; +using F32 = float; +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] = E[m0, m1, n0, n1] +// k/k/n are the fast changing dimension for A/B/E +using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4> + // clang-format on + >; void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance( std::vector>>& instances) + Scale>>>& instances) { add_device_operation_instances( instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance{}); diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp index 4cd28d52dd..ff37bf7cce 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp @@ -9,9 +9,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" + #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { @@ -19,19 +21,43 @@ namespace tensor_operation { namespace device { namespace instance { -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// k/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance = - device_contraction_kn_instance; +using F32 = float; +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] = E[m0, m1, n0, n1] +// k/n/n are the fast changing dimension for A/B/E +using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 1, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 1, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 1, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 1, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 1, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 1, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + // clang-format on + >; void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance( std::vector>>& instances) + Scale>>>& instances) { add_device_operation_instances( instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance{}); diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp index 82bba62715..8a1f6f9334 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp @@ -9,9 +9,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" + #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { @@ -19,19 +21,43 @@ namespace tensor_operation { namespace device { namespace instance { -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// m/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance = - device_contraction_mk_instance; +using F32 = float; +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] = E[m0, m1, n0, n1] +// m/k/n are the fast changing dimension for A/B/E +using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 256, 16, 1, 4, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 4, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 4, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 4, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 4, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + // clang-format on + >; void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance( std::vector>>& instances) + Scale>>>& instances) { add_device_operation_instances( instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance{}); diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp index 06bf4f34b5..d333f59726 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp @@ -9,9 +9,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" + #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { @@ -19,19 +21,43 @@ namespace tensor_operation { namespace device { namespace instance { -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// m/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance = - device_contraction_mn_instance; +using F32 = float; +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] = E[m0, m1, n0, n1] +// m/n/n are the fast changing dimension for A/B/E +using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 256, 16, 1, 1, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 1, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 1, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 1, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 1, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 1, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + // clang-format on + >; void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance( std::vector>>& instances) + Scale>>>& instances) { add_device_operation_instances( instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance{}); diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_kkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_kkn_instance.cpp deleted file mode 100644 index c047153bd9..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_kkn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// k/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_kkn_instance = - device_contraction_f64_kk_instance; - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_kkn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_kkn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_knn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_knn_instance.cpp deleted file mode 100644 index ee33c56223..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_knn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// k/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_knn_instance = - device_contraction_f64_kn_instance; - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_knn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_knn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mkn_instance.cpp deleted file mode 100644 index b62d4bb2aa..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mkn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// m/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mkn_instance = - device_contraction_f64_mk_instance; - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mkn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mkn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mnn_instance.cpp deleted file mode 100644 index 346f91beb0..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mnn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// m/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mnn_instance = - device_contraction_f64_mn_instance; - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mnn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mnn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp index 30cc3b9652..4c87b51a9e 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp @@ -9,9 +9,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" + #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { @@ -19,19 +21,37 @@ namespace tensor_operation { namespace device { namespace instance { +using F64 = double; +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + // A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] // k/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance = - device_contraction_f64_kk_instance; +using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 64, 64, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 32, 16, 2, 2, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 32, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 64, 64, 32, 16, 2, 2, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 64, 32, 64, 16, 2, 2, 16, 16, 2, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1> + // clang-format on + >; void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance( std::vector>>& instances) + Scale>>>& instances) { add_device_operation_instances( instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance{}); diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp index 0a36ef7ffd..fd3f57c6b6 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp @@ -9,9 +9,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" + #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { @@ -19,19 +21,37 @@ namespace tensor_operation { namespace device { namespace instance { +using F64 = double; +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + // A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] // k/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance = - device_contraction_f64_kn_instance; +using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 1, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 1, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 1, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1> + // clang-format on + >; void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance( std::vector>>& instances) + Scale>>>& instances) { add_device_operation_instances( instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance{}); diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp index 617332c2c7..1e53f0b2f5 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp @@ -9,9 +9,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" + #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { @@ -19,19 +21,37 @@ namespace tensor_operation { namespace device { namespace instance { +using F64 = double; +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + // A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] // m/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance = - device_contraction_f64_mk_instance; +using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 2, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 2, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 2, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 2, 16, 16, 2, 4, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1> + // clang-format on + >; void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance( std::vector>>& instances) + Scale>>>& instances) { add_device_operation_instances( instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance{}); diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp index 1df7f31201..d02d146a9b 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp @@ -9,9 +9,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" + #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { @@ -19,19 +21,37 @@ namespace tensor_operation { namespace device { namespace instance { +using F64 = double; +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + // A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] // m/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance = - device_contraction_f64_mn_instance; +using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 1, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 1, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 1, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 1, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 1, 16, 16, 2, 4, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1> + // clang-format on + >; void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance( std::vector>>& instances) + Scale>>>& instances) { add_device_operation_instances( instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance{}); diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp index fed2cbbfb9..b6312a6f51 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp @@ -26,7 +26,7 @@ using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; // Compilation parameters for a[m, k] * b[k, n] = c[m, n] using device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances = std::tuple< @@ -35,6 +35,9 @@ using device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances = std::tuple< //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + //Generic instance + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 2>, + // DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp index 44ac4c08cd..0b1dcad6f2 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp @@ -26,7 +26,7 @@ using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; // Compilation parameters for a[m, k] * b[k, n] = c[m, n] using device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances = std::tuple< @@ -35,6 +35,9 @@ using device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances = std::tuple< //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + //Generic instance + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 2>, + // DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp index 9d15ccd362..1fcaad1ff7 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp @@ -26,7 +26,7 @@ using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; // Compilation parameters for a[k, m] * b[k, n] = c[m, n] using device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances = std::tuple< @@ -35,6 +35,9 @@ using device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances = std::tuple< //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + //Generic instance + DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 1>, + // DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 8>, 4>, diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp index 4e9ad58742..2b53de60d5 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp @@ -26,7 +26,7 @@ using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; // Compilation parameters for a[k, m] * b[n, k] = c[m, n] using device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances = std::tuple< @@ -35,6 +35,9 @@ using device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances = std::tuple< //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + //Generic instances + DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 1>, + // DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 1, 1, S<1, 16, 1, 8>, 4>, diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt index f44e97ed75..0927cf225d 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt @@ -5,7 +5,7 @@ add_instance_library(device_grouped_conv3d_bwd_data_instance xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp - + xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_input_f16_comp_bf8_f8_instance.cpp wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_input_f16_comp_bf8_f8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_input_f16_comp_bf8_f8_instance.cpp new file mode 100644 index 0000000000..d46be53ba4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_input_f16_comp_bf8_f8_instance.cpp @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo, +// g, k] +void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_input_fp16_comp_bf8f8_instances<3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_data_xdl_input_fp16_comp_bf8f8_instances< + 3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt index 6bd9d6e646..65ea645a39 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt @@ -4,7 +4,8 @@ set(GROUPED_CONV3D_BWD_WEIGHT device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp - device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp) + device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp) if(DL_KERNELS) list(APPEND GROUPED_CONV3D_BWD_WEIGHT diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp new file mode 100644 index 0000000000..7f9493f602 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_comp_bf8_f8_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_comp_bf8_f8_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt index 728e2d1c34..c3cc4cb054 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt @@ -4,10 +4,12 @@ add_instance_library(device_grouped_conv3d_fwd_instance xdl/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instance.cpp + xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp + xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_fp8_instance.cpp wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_fp8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_fp8_instance.cpp new file mode 100644 index 0000000000..431dfed50a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_fp8_instance.cpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_f8_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_f8_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_f8_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_f8_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/README.md b/profiler/README.md index b7b5be420b..d081fc33e9 100644 --- a/profiler/README.md +++ b/profiler/README.md @@ -50,23 +50,21 @@ Best Perf: 1.42509 ms, 102.988 TFlops, 234.086 GB/s ## Profile contraction kernels ```bash #arg1: tensor operation (contraction_bilinear=CONTRACTION+Bilinear) -#arg2: data type (0: fp32; 1: f64; 2: f16; 3: bf16) -#arg3: compute data type (0: fp32; 1: f64; 2: f16; 3: bf16) -#arg4: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1]; +#arg2: data type (0: fp32; 1: f64)\n" +#arg3: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1]; # 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1]; # 2: A[k0, k1, m0, m1] * B[k0, k1, n0, n1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1]; # 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1]) -#arg5: verification (0: no; 1: yes) -#arg6: initialization (0: no init; 1: integer value; 2: decimal value) -#arg7: print tensor value (0: no; 1: yes) -#arg8: time kernel (0: no, 1: yes) -#arg9: alpha -#arg10: beta -#arg11 to 16: M0, M1, N0, N1, K0, K1 -#arg17 to 32: Strides for A, B, D and E (skip for default) +#arg4: verification (0: no; 1: yes) +#arg5: initialization (0: no init; 1: integer value; 2: decimal value) +#arg6: print tensor value (0: no; 1: yes) +#arg7: time kernel (0: no, 1: yes) +#arg8 and arg9: alpha and beta +#arg10 to 15: M0, M1, N0, N1, K0, K1 +#arg16 to 31: Strides for A, B, D and E (skip for default) -################ op datatype compute_datatype layout verify init log time alpha beta M0 M1 N0 N1 K0 K1 -./bin/ckProfiler contraction_bilinear 0 0 1 0 0 0 1 1.0 1.0 128 128 128 128 128 128 +################ op datatype layout verify init log time alpha beta M0 M1 N0 N1 K0 K1 +./bin/ckProfiler contraction_bilinear 0 1 0 0 0 1 1.0 1.0 128 128 128 128 128 128 ``` Result (MI100) diff --git a/profiler/include/profiler/profile_contraction_impl.hpp b/profiler/include/profiler/profile_contraction_impl.hpp index f6e4b3f395..660cc3f9e5 100644 --- a/profiler/include/profiler/profile_contraction_impl.hpp +++ b/profiler/include/profiler/profile_contraction_impl.hpp @@ -31,14 +31,10 @@ namespace profiler { using Bilinear = ck::tensor_operation::element_wise::Bilinear; using Scale = ck::tensor_operation::element_wise::Scale; -using F32 = float; -using F64 = double; - template int profile_contraction_impl(ck::index_t do_verification, @@ -49,10 +45,10 @@ int profile_contraction_impl(ck::index_t do_verification, const std::vector& M, const std::vector& N, const std::vector& K, - const std::vector& StridesA, // [M0, M1, K0, K1] - const std::vector& StridesB, // [N0, N1, K0, K1] - const std::vector& StridesE, // [M0, M1, N0, N1] - const std::vector& StridesD) // [M0, M1, N0, N1] + const std::vector& StridesA, + const std::vector& StridesB, + const std::vector& StridesE, + const std::vector& StridesD) { bool pass = true; @@ -67,13 +63,13 @@ int profile_contraction_impl(ck::index_t do_verification, }; Tensor a_m_k(f_host_tensor_descriptor(M, K, StridesA)); - Tensor b_n_k(f_host_tensor_descriptor(N, K, StridesB)); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StridesB)); Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StridesE)); Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StridesE)); Tensor d_m_n(f_host_tensor_descriptor(M, N, StridesD)); std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; - std::cout << "b_n_k: " << b_n_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl; @@ -82,12 +78,12 @@ int profile_contraction_impl(ck::index_t do_verification, case 0: break; case 1: a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_n_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); break; default: a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); } @@ -95,12 +91,12 @@ int profile_contraction_impl(ck::index_t do_verification, using BElementOp = ck::tensor_operation::element_wise::PassThrough; DeviceMem a_device_buf(sizeof(DataType) * a_m_k.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(DataType) * b_n_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(DataType) * b_k_n.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(DataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); DeviceMem d_device_buf(sizeof(DataType) * d_m_n.mDesc.GetElementSpaceSize()); a_device_buf.ToDevice(a_m_k.mData.data()); - b_device_buf.ToDevice(b_n_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); e_device_buf.SetZero(); d_device_buf.ToDevice(d_m_n.mData.data()); @@ -122,8 +118,7 @@ int profile_contraction_impl(ck::index_t do_verification, DataType, AElementOp, BElementOp, - CDElementOp, - ComputeDataType>; + CDElementOp>; // get device op instances const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< @@ -131,9 +126,6 @@ int profile_contraction_impl(ck::index_t do_verification, std::cout << "found " << op_ptrs.size() << " instances" << std::endl; - using AccDataType = - typename std::conditional::value, F64, F32>::type; - // Run reference op if(do_verification) { @@ -144,8 +136,7 @@ int profile_contraction_impl(ck::index_t do_verification, DataType, DataType, DataType, - AccDataType, - ComputeDataType, + DataType, AElementOp, BElementOp>; @@ -155,7 +146,7 @@ int profile_contraction_impl(ck::index_t do_verification, Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StridesE)); auto ref_argument = - ref_op.MakeArgument(a_m_k, b_n_k, c_m_n_host_result, a_element_op, b_element_op); + ref_op.MakeArgument(a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op); ref_invoker.Run(ref_argument); @@ -281,29 +272,8 @@ int profile_contraction_impl(ck::index_t do_verification, { e_device_buf.FromDevice(e_m_n_device_result.mData.data()); - // Both the kernel and the reference use `AccDataType`, so an absolute error of both - // of them is bounded by `nelems_k * std::numeric_limits::epsilon()`. - // Comparing one to another can result in an absolute error as high as twice that - // value. - double threshold = 2 * nelems_k * std::numeric_limits::epsilon(); - // Handle the possible casting error of either AccDataType -> DataType or - // DataType -> ComputeDataType. - // TODO: Add a generic solution for calculating thresholds in CK. - if constexpr(ck::is_same_v || - ck::is_same_v) - { - const double epsilon = std::pow(2, -7); - // Maximum relative casting error when rounding to zero. - threshold += epsilon * 2; - } - else if constexpr(ck::is_same_v || - ck::is_same_v) - { - const double epsilon = std::pow(2, -10); - // Maximum relative casting error when rounding to zero. - threshold += epsilon * 2; - } - + float threshold = + static_cast(nelems_k) * std::numeric_limits::epsilon(); pass = pass & ck::utils::check_err(e_m_n_device_result, e_m_n_host_result, "Error: incorrect results!", @@ -313,7 +283,7 @@ int profile_contraction_impl(ck::index_t do_verification, if(do_log) { LogRangeAsType(std::cout << "a : ", a_m_k.mData, ",") << std::endl; - LogRangeAsType(std::cout << "b: ", b_n_k.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") << std::endl; LogRangeAsType(std::cout << "c_host : ", e_m_n_host_result.mData, ",") << std::endl; LogRangeAsType(std::cout << "c_device: ", e_m_n_device_result.mData, ",") diff --git a/profiler/include/profiler/profile_contraction_utils.hpp b/profiler/include/profiler/profile_contraction_utils.hpp index 05ec7daf9d..076bbd4559 100644 --- a/profiler/include/profiler/profile_contraction_utils.hpp +++ b/profiler/include/profiler/profile_contraction_utils.hpp @@ -23,18 +23,8 @@ enum struct ContractionMatrixLayout enum struct ContractionDataType { - F32_F32_F32_F32, // 0 - F64_F64_F64_F64, // 1 - F16_F16_F16_F16, // 2 - BF16_BF16_BF16_BF16, // 3 -}; - -enum struct ContractionComputeDataType -{ - F32 = 0, - F64, - F16, - BF16, + F32_F32_F32_F32, // 0 + F64_F64_F64_F64, // 1 }; inline void collect_index_params(char* argv[], diff --git a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp index 48bf639a70..0300aed436 100644 --- a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp @@ -33,7 +33,9 @@ template + typename OutDataType, + typename ComputeTypeA = InDataType, + typename ComputeTypeB = ComputeTypeA> bool profile_grouped_conv_bwd_weight_impl(int do_verification, int init_method, bool do_log, @@ -120,7 +122,9 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, OutDataType, InElementOp, WeiElementOp, - OutElementOp>; + OutElementOp, + ComputeTypeA, + ComputeTypeB>; // get device op instances const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< diff --git a/profiler/src/profile_contraction_bilinear.cpp b/profiler/src/profile_contraction_bilinear.cpp index 8cb34ccc6c..6ed1841204 100644 --- a/profiler/src/profile_contraction_bilinear.cpp +++ b/profiler/src/profile_contraction_bilinear.cpp @@ -17,9 +17,8 @@ static void print_helper_msg() { std::cout << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" - << "arg2: data type (0: fp32; 1: f64; 2: f16; 3: bf16)\n" - << "arg3: compute data type (0: fp32; 1: f64; 2: f16; 3: bf16)\n" - << "arg4: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + " + << "arg2: data type (0: fp32; 1: f64)\n" + << "arg3: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + " "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n" << " 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + " "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n" @@ -27,42 +26,40 @@ static void print_helper_msg() "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n" << " 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + " "D[m0, m1, n0, n1] = E[m0, m1, n0, n1])\n" - << "arg5: verification (0: no; 1: yes)\n" - << "arg6: initialization (0: no init; 1: integer value; 2: decimal " + << "arg4: verification (0: no; 1: yes)\n" + << "arg5: initialization (0: no init; 1: integer value; 2: decimal " << "value)\n" - << "arg7: print tensor value (0: no; 1: yes)\n" - << "arg8: time kernel (0: no, 1: yes)\n" - << "arg9: alpha\n" - << "arg10: beta\n" - << "arg11 to 16: M0, M1, N0, N1, K0, K1\n" - << "arg17 to 32: Strides for A, B, D and E (skip for default)\n" + << "arg6: print tensor value (0: no; 1: yes)\n" + << "arg7: time kernel (0: no, 1: yes)\n" + << "arg8 and arg9: alpha and beta\n" + << "arg10 to 15: M0, M1, N0, N1, K0, K1\n" + << "arg16 to 31: Strides for A, B, D and E (skip for default)\n" << std::endl; } int profile_contraction_bilinear(int argc, char* argv[]) { - const bool default_strides = argc == 17; + const bool default_strides = argc == 16; - if(argc != 33 && argc != 17) + if(argc != 32 && argc != 16) { print_helper_msg(); exit(1); } const auto data_type = static_cast(std::stoi(argv[2])); - const auto compute_data_type = static_cast(std::stoi(argv[3])); - const auto layout = static_cast(std::stoi(argv[4])); - const bool do_verification = std::stoi(argv[5]); - const ck::index_t init_method = std::stoi(argv[6]); - const bool do_log = std::stoi(argv[7]); - const bool time_kernel = std::stoi(argv[8]); - const float alpha = std::stof(argv[9]); - const float beta = std::stof(argv[10]); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const ck::index_t init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + const float alpha = std::stof(argv[8]); + const float beta = std::stof(argv[9]); std::vector M; std::vector N; std::vector K; - const ck::index_t dims_arg_num = 11; + const ck::index_t dims_arg_num = 10; collect_index_params(argv, M, dims_arg_num, 2); collect_index_params(argv, N, dims_arg_num + 2, 2); collect_index_params(argv, K, dims_arg_num + 4, 2); @@ -79,130 +76,90 @@ int profile_contraction_bilinear(int argc, char* argv[]) collect_index_params(argv, StridesD, dims_arg_num + 18, 4); } - using F16 = ck::half_t; - using BF16 = ck::bhalf_t; - using F32 = float; - using F64 = double; + using F32 = float; + using F64 = double; - auto profile = - [&](auto a_layout, auto b_layout, auto cde_layout, auto type, auto compute_type) { - using ALayout = decltype(a_layout); - using BLayout = decltype(b_layout); - using CDELayout = decltype(cde_layout); + auto profile = [&](auto a_layout, auto b_layout, auto cde_layout, auto type) { + using ALayout = decltype(a_layout); + using BLayout = decltype(b_layout); + using CDELayout = decltype(cde_layout); - using DataType = decltype(type); - using ComputeDataType = decltype(compute_type); + using DataType = decltype(type); - if(default_strides) - { - assign_default_strides(a_layout, StridesA, {M[0], M[1], K[0], K[1]}); - assign_default_strides(b_layout, StridesB, {N[0], N[1], K[0], K[1]}); - assign_default_strides(cde_layout, StridesE, {M[0], M[1], N[0], N[1]}); - assign_default_strides(cde_layout, StridesD, {M[0], M[1], N[0], N[1]}); - } - bool pass = ck::profiler::profile_contraction_impl, - Bilinear>(do_verification, - init_method, - do_log, - time_kernel, - Bilinear{alpha, beta}, - M, - N, - K, - StridesA, - StridesB, - StridesE, - StridesD); - - return pass; - }; - - auto run_profile_for_datatype = [&](auto type, auto compute_type) { - if(layout == ContractionMatrixLayout::MK_KN_MN_MN) + if(default_strides) { - return profile(Row{}, Row{}, Row{}, type, compute_type); + assign_default_strides(a_layout, StridesA, {M[0], M[1], K[0], K[1]}); + assign_default_strides(b_layout, StridesB, {K[0], K[1], N[0], N[1]}); + assign_default_strides(cde_layout, StridesE, {M[0], M[1], N[0], N[1]}); + assign_default_strides(cde_layout, StridesD, {M[0], M[1], N[0], N[1]}); } - else if(layout == ContractionMatrixLayout::MK_NK_MN_MN) - { - return profile(Row{}, Col{}, Row{}, type, compute_type); - } - else if(layout == ContractionMatrixLayout::KM_KN_MN_MN) - { - return profile(Col{}, Row{}, Row{}, type, compute_type); - } - else if(layout == ContractionMatrixLayout::KM_NK_MN_MN) - { - return profile(Col{}, Col{}, Row{}, type, compute_type); - } - return false; + bool pass = ck::profiler::profile_contraction_impl, + Bilinear>(do_verification, + init_method, + do_log, + time_kernel, + Bilinear{alpha, beta}, + M, + N, + K, + StridesA, + StridesB, + StridesE, + StridesD); + + return pass; }; - if(data_type == ContractionDataType::F32_F32_F32_F32) + if(data_type == ContractionDataType::F32_F32_F32_F32 && + layout == ContractionMatrixLayout::MK_KN_MN_MN) { - if(compute_data_type == ContractionComputeDataType::F32) - { - return run_profile_for_datatype(F32{}, F32{}); - } - else if(compute_data_type == ContractionComputeDataType::F16) - { - return run_profile_for_datatype(F32{}, F16{}); - } - else if(compute_data_type == ContractionComputeDataType::BF16) - { - return run_profile_for_datatype(F32{}, BF16{}); - } - else - { - std::cout << "Incorrect combination of data type and compute data type." << std::endl; - return 1; - } + return profile(Row{}, Row{}, Row{}, F32{}); } - else if(data_type == ContractionDataType::F64_F64_F64_F64) + else if(data_type == ContractionDataType::F32_F32_F32_F32 && + layout == ContractionMatrixLayout::MK_NK_MN_MN) { - if(compute_data_type == ContractionComputeDataType::F64) - { - return run_profile_for_datatype(F64{}, F64{}); - } - else if(compute_data_type == ContractionComputeDataType::F32) - { - return run_profile_for_datatype(F64{}, F32{}); - } - else - { - std::cout << "Incorrect combination of data type and compute data type." << std::endl; - return 1; - } + return profile(Row{}, Col{}, Row{}, F32{}); } - else if(data_type == ContractionDataType::F16_F16_F16_F16) + else if(data_type == ContractionDataType::F32_F32_F32_F32 && + layout == ContractionMatrixLayout::KM_KN_MN_MN) { - if(compute_data_type == ContractionComputeDataType::F32) - { - return run_profile_for_datatype(F16{}, F32{}); - } - else - { - std::cout << "Incorrect combination of data type and compute data type." << std::endl; - return 1; - } + return profile(Col{}, Row{}, Row{}, F32{}); } - else if(data_type == ContractionDataType::BF16_BF16_BF16_BF16) + else if(data_type == ContractionDataType::F32_F32_F32_F32 && + layout == ContractionMatrixLayout::KM_NK_MN_MN) { - if(compute_data_type == ContractionComputeDataType::F32) - { - return run_profile_for_datatype(BF16{}, F32{}); - } - else - { - std::cout << "Incorrect combination of data type and compute data type." << std::endl; - return 1; - } + return profile(Col{}, Col{}, Row{}, F32{}); + } + else if(data_type == ContractionDataType::F64_F64_F64_F64 && + layout == ContractionMatrixLayout::MK_KN_MN_MN) + { + return profile(Row{}, Row{}, Row{}, F64{}); + } + else if(data_type == ContractionDataType::F64_F64_F64_F64 && + layout == ContractionMatrixLayout::MK_NK_MN_MN) + { + return profile(Row{}, Col{}, Row{}, F64{}); + } + else if(data_type == ContractionDataType::F64_F64_F64_F64 && + layout == ContractionMatrixLayout::KM_KN_MN_MN) + { + return profile(Col{}, Row{}, Row{}, F64{}); + } + else if(data_type == ContractionDataType::F64_F64_F64_F64 && + layout == ContractionMatrixLayout::KM_NK_MN_MN) + { + return profile(Col{}, Col{}, Row{}, F64{}); + } + else + { + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; } - return 1; } REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_contraction_bilinear); diff --git a/profiler/src/profile_contraction_scale.cpp b/profiler/src/profile_contraction_scale.cpp index ca9c199988..6784b916f6 100644 --- a/profiler/src/profile_contraction_scale.cpp +++ b/profiler/src/profile_contraction_scale.cpp @@ -17,9 +17,8 @@ static void print_helper_msg() { std::cout << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" - << "arg2: data type (0: fp32; 1: f64; 2: f16; 3: bf16)\n" - << "arg3: compute data type (0: fp32; 1: f64; 2: f16; 3: bf16)\n" - << "arg4: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + " + << "arg2: data type (0: fp32; 1: f64)\n" + << "arg3: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + " "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n" << " 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + " "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n" @@ -27,40 +26,39 @@ static void print_helper_msg() "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n" << " 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + " "D[m0, m1, n0, n1] = E[m0, m1, n0, n1])\n" - << "arg5: verification (0: no; 1: yes)\n" - << "arg6: initialization (0: no init; 1: integer value; 2: decimal " + << "arg4: verification (0: no; 1: yes)\n" + << "arg5: initialization (0: no init; 1: integer value; 2: decimal " << "value)\n" - << "arg7: print tensor value (0: no; 1: yes)\n" - << "arg8: time kernel (0: no, 1: yes)\n" - << "arg9: alpha\n" - << "arg10 to 15: M0, M1, N0, N1, K0, K1\n" - << "arg16 to 31: Strides for A, B, D and E (skip for default)\n" + << "arg6: print tensor value (0: no; 1: yes)\n" + << "arg7: time kernel (0: no, 1: yes)\n" + << "arg8: alpha\n" + << "arg9 to 14: M0, M1, N0, N1, K0, K1\n" + << "arg15 to 30: Strides for A, B, D and E (skip for default)\n" << std::endl; } int profile_contraction_scale(int argc, char* argv[]) { - const bool default_strides = argc == 16; + const bool default_strides = argc == 15; - if(argc != 32 && argc != 16) + if(argc != 31 && argc != 15) { print_helper_msg(); exit(1); } const auto data_type = static_cast(std::stoi(argv[2])); - const auto compute_data_type = static_cast(std::stoi(argv[3])); - const auto layout = static_cast(std::stoi(argv[4])); - const bool do_verification = std::stoi(argv[5]); - const ck::index_t init_method = std::stoi(argv[6]); - const bool do_log = std::stoi(argv[7]); - const bool time_kernel = std::stoi(argv[8]); - const float alpha = std::stof(argv[9]); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const ck::index_t init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + const float alpha = std::stof(argv[8]); std::vector M; std::vector N; std::vector K; - const ck::index_t dims_arg_num = 10; + const ck::index_t dims_arg_num = 9; collect_index_params(argv, M, dims_arg_num, 2); collect_index_params(argv, N, dims_arg_num + 2, 2); collect_index_params(argv, K, dims_arg_num + 4, 2); @@ -77,131 +75,88 @@ int profile_contraction_scale(int argc, char* argv[]) collect_index_params(argv, StridesD, dims_arg_num + 18, 4); } - using F16 = ck::half_t; - using BF16 = ck::bhalf_t; - using F32 = float; - using F64 = double; + using F32 = float; + using F64 = double; - auto profile = - [&](auto a_layout, auto b_layout, auto cde_layout, auto type, auto compute_type) { - using ALayout = decltype(a_layout); - using BLayout = decltype(b_layout); - using CDELayout = decltype(cde_layout); + auto profile = [&](auto a_layout, auto b_layout, auto cde_layout, auto type) { + using ALayout = decltype(a_layout); + using BLayout = decltype(b_layout); + using CDELayout = decltype(cde_layout); - using DataType = decltype(type); - using ComputeDataType = decltype(compute_type); + using DataType = decltype(type); - if(default_strides) - { - assign_default_strides(a_layout, StridesA, {M[0], M[1], K[0], K[1]}); - assign_default_strides(b_layout, StridesB, {N[0], N[1], K[0], K[1]}); - assign_default_strides(cde_layout, StridesE, {M[0], M[1], N[0], N[1]}); - assign_default_strides(cde_layout, StridesD, {M[0], M[1], N[0], N[1]}); - } - - bool pass = ck::profiler::profile_contraction_impl, - Scale>(do_verification, - init_method, - do_log, - time_kernel, - Scale{alpha}, - M, - N, - K, - StridesA, - StridesB, - StridesE, - StridesD); - - return pass; - }; - - auto run_profile_for_datatype = [&](auto type, auto compute_type) { - if(layout == ContractionMatrixLayout::MK_KN_MN_MN) + if(default_strides) { - return profile(Row{}, Row{}, Row{}, type, compute_type); + assign_default_strides(a_layout, StridesA, {M[0], M[1], K[0], K[1]}); + assign_default_strides(b_layout, StridesB, {K[0], K[1], N[0], N[1]}); + assign_default_strides(cde_layout, StridesE, {M[0], M[1], N[0], N[1]}); + assign_default_strides(cde_layout, StridesD, {M[0], M[1], N[0], N[1]}); } - else if(layout == ContractionMatrixLayout::MK_NK_MN_MN) - { - return profile(Row{}, Col{}, Row{}, type, compute_type); - } - else if(layout == ContractionMatrixLayout::KM_KN_MN_MN) - { - return profile(Col{}, Row{}, Row{}, type, compute_type); - } - else if(layout == ContractionMatrixLayout::KM_NK_MN_MN) - { - return profile(Col{}, Col{}, Row{}, type, compute_type); - } - return false; + + bool pass = ck::profiler:: + profile_contraction_impl, Scale>( + do_verification, + init_method, + do_log, + time_kernel, + Scale{alpha}, + M, + N, + K, + StridesA, + StridesB, + StridesE, + StridesD); + + return pass; }; - if(data_type == ContractionDataType::F32_F32_F32_F32) + if(data_type == ContractionDataType::F32_F32_F32_F32 && + layout == ContractionMatrixLayout::MK_KN_MN_MN) { - if(compute_data_type == ContractionComputeDataType::F32) - { - return run_profile_for_datatype(F32{}, F32{}); - } - else if(compute_data_type == ContractionComputeDataType::F16) - { - return run_profile_for_datatype(F32{}, F16{}); - } - else if(compute_data_type == ContractionComputeDataType::BF16) - { - return run_profile_for_datatype(F32{}, BF16{}); - } - else - { - std::cout << "Incorrect combination of data type and compute data type." << std::endl; - return 1; - } + return profile(Row{}, Row{}, Row{}, F32{}); } - else if(data_type == ContractionDataType::F64_F64_F64_F64) + else if(data_type == ContractionDataType::F32_F32_F32_F32 && + layout == ContractionMatrixLayout::MK_NK_MN_MN) { - if(compute_data_type == ContractionComputeDataType::F64) - { - return run_profile_for_datatype(F64{}, F64{}); - } - else if(compute_data_type == ContractionComputeDataType::F32) - { - return run_profile_for_datatype(F64{}, F32{}); - } - else - { - std::cout << "Incorrect combination of data type and compute data type." << std::endl; - return 1; - } + return profile(Row{}, Col{}, Row{}, F32{}); } - else if(data_type == ContractionDataType::F16_F16_F16_F16) + else if(data_type == ContractionDataType::F32_F32_F32_F32 && + layout == ContractionMatrixLayout::KM_KN_MN_MN) { - if(compute_data_type == ContractionComputeDataType::F32) - { - return run_profile_for_datatype(F16{}, F32{}); - } - else - { - std::cout << "Incorrect combination of data type and compute data type." << std::endl; - return 1; - } + return profile(Col{}, Row{}, Row{}, F32{}); } - else if(data_type == ContractionDataType::BF16_BF16_BF16_BF16) + else if(data_type == ContractionDataType::F32_F32_F32_F32 && + layout == ContractionMatrixLayout::KM_NK_MN_MN) { - if(compute_data_type == ContractionComputeDataType::F32) - { - return run_profile_for_datatype(BF16{}, F32{}); - } - else - { - std::cout << "Incorrect combination of data type and compute data type." << std::endl; - return 1; - } + return profile(Col{}, Col{}, Row{}, F32{}); + } + else if(data_type == ContractionDataType::F64_F64_F64_F64 && + layout == ContractionMatrixLayout::MK_KN_MN_MN) + { + return profile(Row{}, Row{}, Row{}, F64{}); + } + else if(data_type == ContractionDataType::F64_F64_F64_F64 && + layout == ContractionMatrixLayout::MK_NK_MN_MN) + { + return profile(Row{}, Col{}, Row{}, F64{}); + } + else if(data_type == ContractionDataType::F64_F64_F64_F64 && + layout == ContractionMatrixLayout::KM_KN_MN_MN) + { + return profile(Col{}, Row{}, Row{}, F64{}); + } + else if(data_type == ContractionDataType::F64_F64_F64_F64 && + layout == ContractionMatrixLayout::KM_NK_MN_MN) + { + return profile(Col{}, Col{}, Row{}, F64{}); + } + else + { + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; } - return 1; } REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_contraction_scale); diff --git a/profiler/src/profile_grouped_conv_bwd_weight.cpp b/profiler/src/profile_grouped_conv_bwd_weight.cpp index be8a3230f7..9a9b5a0f74 100644 --- a/profiler/src/profile_grouped_conv_bwd_weight.cpp +++ b/profiler/src/profile_grouped_conv_bwd_weight.cpp @@ -20,9 +20,10 @@ enum struct ConvLayout enum struct ConvDataType { - F32_F32_F32, // 0 - F16_F16_F16, // 1 - BF16_F32_BF16, // 2 + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_F32_BF16, // 2 + F16_F16_F16_BF8_F8 // 3 }; #define OP_NAME "grouped_conv_bwd_weight" @@ -33,7 +34,8 @@ static void print_helper_msg() std::cout << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" << "arg2: data type (0: Input fp32, Weight fp32, Output fp32\n" << " 1: Input fp16, Weight fp16, Output fp16\n" - << " 2: Input bf16, Weight fp32, Output bf16)\n" + << " 2: Input bf16, Weight fp32, Output bf16\n" + << " 3: Input fp16, Weight fp16, Output fp16, Gemm bf8@fp8)\n" << "arg3: tensor layout (0: Input[G, N, C, Hi, Wi], Weight[G, K, C, Y, X], Output[G, " "N, K, Ho, Wo]\n" << " 1: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, " @@ -82,6 +84,12 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) using F32 = float; using F16 = ck::half_t; using BF16 = ck::bhalf_t; +#ifdef CK_ENABLE_FP8 + using F8 = ck::f8_t; +#endif +#ifdef CK_ENABLE_BF8 + using BF8 = ck::bf8_t; +#endif using namespace ck::tensor_layout::convolution; @@ -95,7 +103,9 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) auto out_layout, auto in_type, auto wei_type, - auto out_type) { + auto out_type, + auto compute_type_a, + auto compute_type_b) { constexpr ck::index_t NDimSpatial = num_dim_spatial_tmp.value; using InLayout = decltype(in_layout); @@ -106,13 +116,18 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) using WeiDataType = decltype(wei_type); using OutDataType = decltype(out_type); + using ComputeTypeA = decltype(compute_type_a); + using ComputeTypeB = decltype(compute_type_b); + bool pass = ck::profiler::profile_grouped_conv_bwd_weight_impl( + OutDataType, + ComputeTypeA, + ComputeTypeB>( do_verification, init_method, do_log, time_kernel, params, split_k); return pass ? 0 : 1; @@ -122,80 +137,84 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) { if(data_type == ConvDataType::F32_F32_F32) { - return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}); + return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, F32{}, F32{}); } else if(data_type == ConvDataType::F16_F16_F16) { - return profile(I1, GNWC{}, GKXC{}, GNWK{}, F16{}, F16{}, F16{}); + return profile(I1, GNWC{}, GKXC{}, GNWK{}, F16{}, F16{}, F16{}, F16{}, F16{}); } else if(data_type == ConvDataType::BF16_F32_BF16) { // fp32 atomic add is used for weight tensor in bf16 kernel - return profile(I1, GNWC{}, GKXC{}, GNWK{}, BF16{}, F32{}, BF16{}); + return profile(I1, GNWC{}, GKXC{}, GNWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{}); } } else if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) { if(data_type == ConvDataType::F32_F32_F32) { - return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}); + return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, F32{}, F32{}); } else if(data_type == ConvDataType::F16_F16_F16) { - return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F16{}, F16{}, F16{}); + return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F16{}, F16{}, F16{}, F16{}, F16{}); } else if(data_type == ConvDataType::BF16_F32_BF16) { // fp32 atomic add is used for weight tensor in bf16 kernel - return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, BF16{}, F32{}, BF16{}); + return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{}); } } else if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) { if(data_type == ConvDataType::F32_F32_F32) { - return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}); + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{}); } else if(data_type == ConvDataType::F16_F16_F16) { - return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{}); + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{}); } else if(data_type == ConvDataType::BF16_F32_BF16) { // fp32 atomic add is used for weight tensor in bf16 kernel - return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, F32{}, BF16{}); + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{}); } } else if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) { if(data_type == ConvDataType::F32_F32_F32) { - return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}); + return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, F32{}, F32{}); } else if(data_type == ConvDataType::F16_F16_F16) { - return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F16{}, F16{}, F16{}); + return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F16{}, F16{}, F16{}, F16{}, F16{}); } else if(data_type == ConvDataType::BF16_F32_BF16) { // fp32 atomic add is used for weight tensor in bf16 kernel - return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, F32{}, BF16{}); + return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{}); } } else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) { if(data_type == ConvDataType::F32_F32_F32) { - return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}); + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{}); } else if(data_type == ConvDataType::F16_F16_F16) { - return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}); + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{}); } else if(data_type == ConvDataType::BF16_F32_BF16) { // fp32 atomic add is used for weight tensor in bf16 kernel - return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, F32{}, BF16{}); + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::F16_F16_F16_BF8_F8) + { + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, BF8{}, F8{}); } } diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 07dd675afc..4b2c7bbf38 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -32,7 +32,7 @@ function(add_test_executable TEST_NAME) set(test 0) break() elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR - source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND + source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND NOT(source MATCHES type OR source MATCHES type1)) #if filename contains a type which doesn't match any selected type, mark it for removal set(test 1) @@ -61,7 +61,7 @@ function(add_test_executable TEST_NAME) set(result 0) endif() #message("add_test returns ${result}") - return(PROPAGATE result) + set(result ${result} PARENT_SCOPE) endfunction(add_test_executable TEST_NAME) include(GoogleTest) @@ -91,7 +91,7 @@ function(add_gtest_executable TEST_NAME) set(test 0) break() elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR - source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND + source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND NOT(source MATCHES type OR source MATCHES type1)) #if filename contains a type which doesn't match any selected type, mark it for removal set(test 1) @@ -123,7 +123,7 @@ function(add_gtest_executable TEST_NAME) set(result 0) endif() #message("add_gtest returns ${result}") - return(PROPAGATE result) + set(result ${result} PARENT_SCOPE) endfunction(add_gtest_executable TEST_NAME) add_subdirectory(magic_number_division) diff --git a/test/contraction/test_contraction.cpp b/test/contraction/test_contraction.cpp index 958d5be38c..c86b849235 100644 --- a/test/contraction/test_contraction.cpp +++ b/test/contraction/test_contraction.cpp @@ -10,12 +10,9 @@ #include #include "profiler/profile_contraction_impl.hpp" -#include "profiler/profile_contraction_utils.hpp" -using F16 = ck::half_t; -using BF16 = ck::bhalf_t; -using F32 = float; -using F64 = double; +using F32 = float; +using F64 = double; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -23,49 +20,49 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; using Bilinear = ck::tensor_operation::element_wise::Bilinear; using Scale = ck::tensor_operation::element_wise::Scale; -struct Dimensions +struct MemoryParams { std::vector M; std::vector N; std::vector K; + std::vector StridesA; + std::vector StridesB; + std::vector StridesC; + std::vector StridesD; }; template class TestContraction : public ::testing::Test { protected: - using ALayout = std::tuple_element_t<0, Tuple>; - using BLayout = std::tuple_element_t<1, Tuple>; - using CDLayout = std::tuple_element_t<2, Tuple>; - using DataType = std::tuple_element_t<3, Tuple>; - using DTupleDataType = std::tuple_element_t<4, Tuple>; - using ComputeDataType = std::tuple_element_t<5, Tuple>; - using CDElementOp = std::tuple_element_t<6, Tuple>; + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using CDLayout = std::tuple_element_t<2, Tuple>; + using DataType = std::tuple_element_t<3, Tuple>; + using DTupleDataType = std::tuple_element_t<4, Tuple>; + using CDElementOp = std::tuple_element_t<5, Tuple>; - std::vector dimension_list = {{{32, 32}, {32, 32}, {32, 32}}, - {{16, 16}, {32, 32}, {16, 16}}}; + std::vector list_of_memory_params = {{{32, 32}, + {32, 32}, + {32, 32}, + {32768, 1024, 32, 1}, + {32768, 1024, 32, 1}, + {32768, 1024, 32, 1}, + {32768, 1024, 32, 1}}, + {{16, 16}, + {32, 32}, + {16, 16}, + {4096, 256, 16, 1}, + {16, 1, 8192, 256}, + {16384, 1024, 32, 1}, + {16384, 1024, 32, 1}}}; - std::vector init_methods = {1, 2}; + std::vector init_methods = {0, 1, 2}; std::unique_ptr p_cd_element_op; - void Run() { - for(auto& dimension_params : dimension_list) + for(auto& memory_params : list_of_memory_params) { - std::vector StridesA; - std::vector StridesB; - std::vector StridesC; - std::vector StridesD; - - const auto& M = dimension_params.M; - const auto& N = dimension_params.N; - const auto& K = dimension_params.K; - - assign_default_strides(ALayout{}, StridesA, {M[0], M[1], K[0], K[1]}); - assign_default_strides(BLayout{}, StridesB, {N[0], N[1], K[0], K[1]}); - assign_default_strides(CDLayout{}, StridesC, {M[0], M[1], N[0], N[1]}); - assign_default_strides(CDLayout{}, StridesD, {M[0], M[1], N[0], N[1]}); - for(const ck::index_t init_method : init_methods) { bool pass = @@ -73,20 +70,19 @@ class TestContraction : public ::testing::Test BLayout, CDLayout, DataType, - ComputeDataType, DTupleDataType, CDElementOp>(true /*do_verification*/, init_method, false /*do_logs*/, false /*time_kernel*/, *p_cd_element_op, - dimension_params.M, - dimension_params.N, - dimension_params.K, - StridesA, - StridesB, - StridesC, - StridesD); + memory_params.M, + memory_params.N, + memory_params.K, + memory_params.StridesA, + memory_params.StridesB, + memory_params.StridesC, + memory_params.StridesD); EXPECT_TRUE(pass); } } @@ -103,18 +99,24 @@ class TestContractionBilinear : public TestContraction { }; -#define ALL_LAYOUT_COMBINATIONS(dt, tuple_dt, compute_dt, op) \ - std::tuple, \ - std::tuple, \ - std::tuple, \ - std::tuple - using BilinearKernelTypes = - ::testing::Types, F32, Bilinear), - ALL_LAYOUT_COMBINATIONS(F64, ck::Tuple, F64, Bilinear)>; + ::testing::Types, Bilinear>, + std::tuple, Bilinear>, + std::tuple, Bilinear>, + std::tuple, Bilinear>, + std::tuple, Bilinear>, + std::tuple, Bilinear>, + std::tuple, Bilinear>, + std::tuple, Bilinear>>; -using ScaleKernelTypes = ::testing::Types, F32, Scale), - ALL_LAYOUT_COMBINATIONS(F64, ck::Tuple<>, F64, Scale)>; +using ScaleKernelTypes = ::testing::Types, Scale>, + std::tuple, Scale>, + std::tuple, Scale>, + std::tuple, Scale>, + std::tuple, Scale>, + std::tuple, Scale>, + std::tuple, Scale>, + std::tuple, Scale>>; TYPED_TEST_SUITE(TestContractionBilinear, BilinearKernelTypes); TYPED_TEST_SUITE(TestContractionScale, ScaleKernelTypes); @@ -134,46 +136,3 @@ TYPED_TEST(TestContractionScale, scale) this->p_cd_element_op = std::make_unique(0.5f); this->Run(); } - -template -class TestContractionScaleMixedPrecision : public TestContraction -{ -}; - -template -class TestContractionBilinearMixedPrecision : public TestContraction -{ -}; - -using BilinearKernelTypesMixedPrecision = - ::testing::Types, F16, Bilinear), - ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple, BF16, Bilinear), - ALL_LAYOUT_COMBINATIONS(F64, ck::Tuple, F32, Bilinear), - ALL_LAYOUT_COMBINATIONS(F16, ck::Tuple, F32, Bilinear), - ALL_LAYOUT_COMBINATIONS(BF16, ck::Tuple, F32, Bilinear)>; - -using ScaleKernelTypesMixedPrecision = - ::testing::Types, F16, Scale), - ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple<>, BF16, Scale), - ALL_LAYOUT_COMBINATIONS(F64, ck::Tuple<>, F32, Scale), - ALL_LAYOUT_COMBINATIONS(F16, ck::Tuple<>, F32, Scale), - ALL_LAYOUT_COMBINATIONS(BF16, ck::Tuple<>, F32, Scale)>; - -TYPED_TEST_SUITE(TestContractionBilinearMixedPrecision, BilinearKernelTypesMixedPrecision); -TYPED_TEST_SUITE(TestContractionScaleMixedPrecision, ScaleKernelTypesMixedPrecision); - -TYPED_TEST(TestContractionBilinearMixedPrecision, bilinear) -{ - this->p_cd_element_op = std::make_unique(1.f, 1.f); - this->Run(); - this->p_cd_element_op = std::make_unique(-0.5f, 0.5f); - this->Run(); -} - -TYPED_TEST(TestContractionScaleMixedPrecision, scale) -{ - this->p_cd_element_op = std::make_unique(1.f); - this->Run(); - this->p_cd_element_op = std::make_unique(0.5f); - this->Run(); -} diff --git a/test/contraction/test_contraction_interface.cpp b/test/contraction/test_contraction_interface.cpp index e0f547445e..12a307f5da 100644 --- a/test/contraction/test_contraction_interface.cpp +++ b/test/contraction/test_contraction_interface.cpp @@ -34,11 +34,11 @@ class ContractionInstanceWrapper static constexpr ck::index_t NumDim = 2; // clang-format off using ContractionDeviceInstance = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| Compute| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Data| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Type| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, F32, F32, F32, F32, ck::Tuple, F32, F32, Pass, Pass, Bilinear, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, ABlockTransferSrcVectorDim, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, BBlockTransferSrcVectorDim, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, CDEBlockTransferScalarPerVector>; + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, F32, F32, F32, F32, ck::Tuple, F32, Pass, Pass, Bilinear, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, ABlockTransferSrcVectorDim, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, BBlockTransferSrcVectorDim, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, CDEBlockTransferScalarPerVector>; // clang-format on bool isSupported(std::vector& ADims,