diff --git a/example/65_gemm_multiply_multiply/CMakeLists.txt b/example/65_gemm_multiply_multiply/CMakeLists.txt index 8d2cbb1c70..d262293440 100644 --- a/example/65_gemm_multiply_multiply/CMakeLists.txt +++ b/example/65_gemm_multiply_multiply/CMakeLists.txt @@ -1,7 +1,6 @@ add_example_executable(example_gemm_multiply_multiply_xdl_fp8 gemm_multiply_multiply_xdl_fp8.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_fp8_ab_scale gemm_multiply_multiply_xdl_fp8_ab_scale.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle.cpp) -add_example_executable(example_gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle_v1 gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle_v1.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_fp16_bpreshuffle gemm_multiply_multiply_xdl_fp16_bpreshuffle.cpp) add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp) @@ -12,7 +11,6 @@ list(APPEND EXAMPLE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker) # list(APPEND EXAMPLE_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --schedmodel=0 -mllvm -misched=gcn-iterative-max-occupancy-experimental") target_compile_options(example_gemm_multiply_multiply_xdl_fp8_ab_scale PRIVATE ${EXAMPLE_COMPILE_OPTIONS}) target_compile_options(example_gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle PRIVATE ${EXAMPLE_COMPILE_OPTIONS}) -target_compile_options(example_gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle_v1 PRIVATE ${EXAMPLE_COMPILE_OPTIONS}) target_compile_options(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle PRIVATE ${EXAMPLE_COMPILE_OPTIONS}) add_example_executable(example_moe_gemm1_xdl_fp8 moe_gemm1_xdl_fp8.cpp) add_example_executable(example_moe_gemm2_xdl_fp8 moe_gemm2_xdl_fp8.cpp) diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle.cpp index 40a8a01b24..268e808c5a 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle.cpp @@ -97,10 +97,10 @@ using DeviceOpInstance = A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, Scale_Block_M, Scale_Block_N, Scale_Block_K, - 128, 128, + 64, 128, 128, 16, 16, 16, 16, - 8, 2, + 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8>, @@ -238,6 +238,16 @@ int main(int argc, char* argv[]) b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); } + // printf("a1_m_k: \n"); + // for(int i = 0; i < (M + Scale_Block_M - 1) / Scale_Block_M; ++i) + // { + // for(int j = 0; j < (K + Scale_Block_K - 1) / Scale_Block_K; ++j) + // { + // printf("%f ", a1_m_k(i, j)); + // } + // printf("\n"); + // } + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); DeviceMem a1_device_buf(sizeof(A1DataType) * a1_m_k.mDesc.GetElementSpaceSize()); DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle_v1.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle_v1.cpp deleted file mode 100644 index 16f7a79367..0000000000 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle_v1.cpp +++ /dev/null @@ -1,382 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. - -#include -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" - -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/literals.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" -#include "ck/library/utility/check_err.hpp" - -#include "ck/utility/blkgemmpipe_scheduler.hpp" - -template -using S = ck::Sequence; - -using BF16 = ck::bhalf_t; -using FP8 = ck::f8_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -using A0DataType = FP8; -using A1DataType = F32; -using B0DataType = FP8; -using B1DataType = F32; -using AccDataType = F32; -using CShuffleDataType = F32; -using DsDataType = ck::Tuple<>; -using EDataType = BF16; - -using A0Layout = Row; -using B0Layout = Col; -using D0Layout = Row; -using D1Layout = Col; -using DsLayout = ck::Tuple<>; -using ELayout = Row; - -void preShuffleBuffer(const FP8* src, FP8* dst, int N, int K, int NXdl) -{ - int KPack = 16; - int NLane = NXdl; - int KLane = 64 / NLane; - - int K0 = K / (KLane * KPack); - // K -> K0 KLane KPack - // N -> N0 NLane - // N, K -> N0 K0 KLane NLane KPack - int tempk; - for(int n = 0; n < N; ++n) - { - for(int k = 0; k < K; ++k) - { - int n0 = n / NLane; - int n1 = n % NLane; - - int k0 = k / (KLane * KPack); - tempk = k % (KLane * KPack); - int k1 = tempk / KPack; - int k2 = tempk % KPack; - - int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + - k1 * KPack * NLane + n1 * KPack + k2; - - dst[outputIndex] = src[n * K + k]; - } - } -} -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -using AElementOp = PassThrough; -using BElementOp = PassThrough; -using CDEElementOp = PassThrough; - -static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; - -static constexpr ck::index_t Scale_Block_M = 1; -static constexpr ck::index_t Scale_Block_N = 128; -static constexpr ck::index_t Scale_Block_K = 128; - -using DeviceOpInstance = - ck::tensor_operation::device::DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle - // clang-format off - , S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, - S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, - 1, 1, S<1, 16, 1, 16>, S<8>, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, FP8>; -// clang-format on - -int main(int argc, char* argv[]) -{ - bool do_verification = true; - int init_method = 1; - bool time_kernel = false; - bool flush_cache = true; - - // GEMM shape - ck::index_t M = 128; - ck::index_t N = 1024; - ck::index_t K = 1024; - - ck::index_t StrideA = K; - ck::index_t StrideB = K; - ck::index_t StrideE = N; - - if(argc == 1) - { - // use default case - } - else if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else if(argc == 8) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - - M = std::stoi(argv[4]); - N = std::stoi(argv[5]); - K = std::stoi(argv[6]); - - flush_cache = std::stoi(argv[7]); - - StrideA = K; - StrideB = K; - StrideE = N; - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=no, 1=yes)\n"); - printf("arg4 to 6: M, N, K\n"); - printf("arg7: flush both I$ and L2$ (0=no, 1=yes)\n"); - exit(0); - } - - ck::index_t Scale_Stride_AM = (K + Scale_Block_K - 1) / Scale_Block_K; - ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K; - - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - using namespace ck::literals; - - if(std::is_same::value) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; - - Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); - Tensor a1_m_k(f_host_tensor_descriptor((M + Scale_Block_M - 1) / Scale_Block_M, - (K + Scale_Block_K - 1) / Scale_Block_K, - Scale_Stride_AM, - A0Layout{})); - Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); - Tensor b0_preshuffled( - f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); // use laout only for size - Tensor b1_k_n(f_host_tensor_descriptor((K + Scale_Block_K - 1) / Scale_Block_K, - (N + Scale_Block_N - 1) / Scale_Block_N, - Scale_Stride_BN, - B0Layout{})); - Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); - Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); - - std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; - std::cout << "a1_m_k: " << a1_m_k.mDesc << std::endl; - std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; - std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl; - std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; - -#if 1 - switch(init_method) - { - case 0: break; - case 1: - a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - a1_m_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); - b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); - break; - case 2: - a0_m_k.GenerateTensorValue(GeneratorTensor_1{}); - b0_k_n.GenerateTensorValue(GeneratorTensor_1{}); - a1_m_k.GenerateTensorValue(GeneratorTensor_1{}); - b1_k_n.GenerateTensorValue(GeneratorTensor_1{}); - break; - case 3: - a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - a1_m_k.GenerateTensorValue(GeneratorTensor_1{}); - b1_k_n.GenerateTensorValue(GeneratorTensor_1{}); - break; - case 4: - a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - a1_m_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); - b1_k_n.GenerateTensorValue(GeneratorTensor_1{}); - break; - case 5: - a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - a1_m_k.GenerateTensorValue(GeneratorTensor_1{}); - b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); - break; - default: - a0_m_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - b0_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - a1_m_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); - b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); - } -#endif -#if 0 - for(int im =0; im< (M + Scale_Block_M - 1) / Scale_Block_M; im++){ - float row_sum = .0; - for(int ik =0; ik< (K + Scale_Block_K - 1) / Scale_Block_K; ik++){ - printf("%lf ",a1_m_k(im, ik)); - row_sum += a1_m_k(im, ik); - } - printf("sum: %lf\n", row_sum * 128); - } -#endif - - DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); - DeviceMem a1_device_buf(sizeof(A1DataType) * a1_m_k.mDesc.GetElementSpaceSize()); - DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); - DeviceMem b1_device_buf(sizeof(B1DataType) * b1_k_n.mDesc.GetElementSpaceSize()); - DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); - - a0_device_buf.ToDevice(a0_m_k.mData.data()); - a1_device_buf.ToDevice(a1_m_k.mData.data()); - b1_device_buf.ToDevice(b1_k_n.mData.data()); - - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto cde_element_op = CDEElementOp{}; - - constexpr ck::index_t NumDTensor = DsDataType::Size(); - - // do GEMM - auto device_op = DeviceOpInstance{}; - int NPerXdl = device_op.GetPreShuffleParameters(); - - preShuffleBuffer(b0_k_n.mData.data(), b0_preshuffled.mData.data(), N, K, NPerXdl); - - b0_device_buf.ToDevice(b0_preshuffled.mData.data()); - auto invoker = device_op.MakeInvoker(); - auto argument = device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(), - b0_device_buf.GetDeviceBuffer(), - std::array{}, - e_device_buf.GetDeviceBuffer(), - M, - N, - K, - StrideA, - StrideB, - std::array{}, - StrideE, - a1_device_buf.GetDeviceBuffer(), - b1_device_buf.GetDeviceBuffer(), - a_element_op, - b_element_op, - cde_element_op); - - if(!device_op.IsSupportedArgument(argument)) - { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); - } - - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_btype = - sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; - - float ave_time = .0; - - if(flush_cache) - { - int rotating_buf = (512 * 1024 * 1024 + num_btype - 1) / num_btype; - - ave_time = invoker.Run(argument, - StreamConfig{nullptr, time_kernel, 0, 50, 100, true, rotating_buf}); - } - else - { - ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 50, 100}); - } - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" - << std::endl; - - if(do_verification) - { - Tensor c_m_n({M, N}); - Tensor a_m_k({M, K}); - Tensor b_k_n({K, N}); - - for(int m = 0; m < M; m++) - { - for(int k = 0; k < K; k++) - { - a_m_k(m, k) = ck::type_convert(a0_m_k(m, k)) * - a1_m_k(m / Scale_Block_M, k / Scale_Block_K); - } - } - - for(int n = 0; n < N; n++) - { - for(int k = 0; k < K; k++) - { - b_k_n(k, n) = ck::type_convert(b0_k_n(k, n)) * - b1_k_n(k / Scale_Block_K, n / Scale_Block_N); - } - } - - using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = - ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); - - ref_invoker.Run(ref_argument); - -#if 1 - for(int m = 0; m < M; ++m) - { - for(int n = 0; n < N; ++n) - { - e_m_n_host_result(m, n) = ck::type_convert(c_m_n(m, n)); - } - } -#endif - - e_device_buf.FromDevice(e_m_n_device_result.mData.data()); - - return ck::utils::check_err( - e_m_n_device_result, e_m_n_host_result, "Error: Incorrect results!", 5e-2, 5e-2) - ? 0 - : 1; - } - - return 0; -} diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v1.hpp index 07496e6db4..555dbbf058 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v1.hpp @@ -133,6 +133,7 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v1{}); constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); - constexpr index_t K2 = KPack; + constexpr index_t K2 = KPack / KGroup; constexpr index_t K1 = 64 / NPerXDL; - constexpr index_t K0 = KRepeat; + constexpr index_t K0 = KRepeat * KGroup; return transform_tensor_descriptor( TileDesc_M0_M1_M2_K{}, @@ -389,6 +390,9 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v1{})); } + // printf("Tid: %d| a_scale_thread_buf: %f %f\n", get_thread_local_1d_id(), + // a_scale_thread_buf[Number<0>{}], + // a_scale_thread_buf[Number<1>{}]); b_scale_thread_copy.Run(b_scale_grid_desc, b_scale_grid_buf, @@ -409,12 +413,58 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v1{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); +#if defined(__gfx950__) && 0 + printf( + "Tid: %02d, a_thread_buf: %02x %02x %02x %02x %02x %02x %02x %02x| %02x " + "%02x %02x %02x %02x %02x %02x %02x| %02x %02x %02x %02x %02x %02x %02x " + "%02x| %02x %02x %02x %02x %02x %02x %02x %02x|\n", + get_thread_local_1d_id(), + *(reinterpret_cast(&(a_thread_buf[Number<0>{}]))), + *(reinterpret_cast(&(a_thread_buf[Number<1>{}]))), + *(reinterpret_cast(&(a_thread_buf[Number<2>{}]))), + *(reinterpret_cast(&(a_thread_buf[Number<3>{}]))), + *(reinterpret_cast(&(a_thread_buf[Number<0 + 4>{}]))), + *(reinterpret_cast(&(a_thread_buf[Number<1 + 4>{}]))), + *(reinterpret_cast(&(a_thread_buf[Number<2 + 4>{}]))), + *(reinterpret_cast(&(a_thread_buf[Number<3 + 4>{}]))), + *(reinterpret_cast(&(a_thread_buf[Number<8 + 0>{}]))), + *(reinterpret_cast(&(a_thread_buf[Number<8 + 1>{}]))), + *(reinterpret_cast(&(a_thread_buf[Number<8 + 2>{}]))), + *(reinterpret_cast(&(a_thread_buf[Number<8 + 3>{}]))), + *(reinterpret_cast(&(a_thread_buf[Number<8 + 0 + 4>{}]))), + *(reinterpret_cast(&(a_thread_buf[Number<8 + 1 + 4>{}]))), + *(reinterpret_cast(&(a_thread_buf[Number<8 + 2 + 4>{}]))), + *(reinterpret_cast(&(a_thread_buf[Number<8 + 3 + 4>{}]))), + *(reinterpret_cast(&(a_thread_buf[Number<16 + 0>{}]))), + *(reinterpret_cast(&(a_thread_buf[Number<16 + 1>{}]))), + *(reinterpret_cast(&(a_thread_buf[Number<16 + 2>{}]))), + *(reinterpret_cast(&(a_thread_buf[Number<16 + 3>{}]))), + *(reinterpret_cast(&(a_thread_buf[Number<16 + 0 + 4>{}]))), + *(reinterpret_cast(&(a_thread_buf[Number<16 + 1 + 4>{}]))), + *(reinterpret_cast(&(a_thread_buf[Number<16 + 2 + 4>{}]))), + *(reinterpret_cast(&(a_thread_buf[Number<16 + 3 + 4>{}]))), + *(reinterpret_cast(&(a_thread_buf[Number<16 + 8 + 0>{}]))), + *(reinterpret_cast(&(a_thread_buf[Number<16 + 8 + 1>{}]))), + *(reinterpret_cast(&(a_thread_buf[Number<16 + 8 + 2>{}]))), + *(reinterpret_cast(&(a_thread_buf[Number<16 + 8 + 3>{}]))), + *(reinterpret_cast( + &(a_thread_buf[Number<16 + 8 + 0 + 4>{}]))), + *(reinterpret_cast( + &(a_thread_buf[Number<16 + 8 + 1 + 4>{}]))), + *(reinterpret_cast( + &(a_thread_buf[Number<16 + 8 + 2 + 4>{}]))), + *(reinterpret_cast( + &(a_thread_buf[Number<16 + 8 + 3 + 4>{}])))); +#endif + }); }); }); @@ -520,12 +570,15 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v1{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); }); @@ -683,12 +736,15 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v1{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); }); @@ -806,6 +862,7 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v1(), c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); }); + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); @@ -836,7 +893,7 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v1, + Sequence<1, 1, 1, 1, 1, KPack / KGroup>, Sequence<0, 1, 2, 3, 4, 5>, 5, A_K1, diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v3.hpp index 3037a229f4..1a0b04a0e8 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v3.hpp @@ -535,12 +535,13 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); }); }); }); @@ -720,7 +721,7 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3{}), + Number{}), a_thread_buf); }); }); @@ -746,7 +747,7 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3{}), + Number{}), a_thread_buf); }); }); @@ -772,7 +773,7 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3{}), + Number{}), a_thread_buf); }); }); @@ -876,19 +877,22 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3{}([&](auto k0) { static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(I1), - a_thread_desc_, - make_tuple( - Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), - a_thread_buf); + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple(Number<(m0 + 2) % 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); }); }); } @@ -896,19 +900,22 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3{}([&](auto k0) { static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(I1), - a_thread_desc_, - make_tuple( - Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), - a_thread_buf); + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple(Number<(m0 + 2) % 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); }); }); } @@ -916,19 +923,22 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3{}([&](auto k0) { static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 2) % MRepeat>{}, - I0, - I0, - Number{}, - I0, - I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple( - Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), - a_thread_buf); + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(Number<(m0 + 2) % 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); }); }); } @@ -1026,7 +1036,7 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3{}), + Number{}), a_thread_buf); }); }); @@ -1114,8 +1124,12 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3{}, I0, I0, Number{}, I0, I0), a_block_buf.At(I0), a_thread_desc_, - make_tuple( - Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + make_tuple(Number<(m0 + 2) % 2>{}, + I0, + I0, + k0, + I0, + Number{}), a_thread_buf); }); }); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp index 9e1ae0ead7..87ed771017 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp @@ -1286,7 +1286,7 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle 0, 1, 1, - false>( + true>( a_scale_grid_desc_am_ak, make_multi_index(block_m_id * MPerBlock / ScaleBlockM + a_thread_offset, 0)); @@ -1300,7 +1300,7 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle 1, ScaleSliceSizeK, 1, - false>( + true>( b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0)); // constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1); @@ -1788,7 +1788,7 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle 0, 1, 1, - false>( + true>( a_scale_grid_desc_am_ak, make_multi_index(block_m_id * MPerBlock / ScaleBlockM + a_thread_offset, 0)); @@ -1802,7 +1802,7 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle 1, ScaleSliceSizeK, 1, - false>( + true>( b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0)); // constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1); diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index 2255505985..b8a52a9efa 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -361,6 +361,99 @@ struct ThreadwiseTensorSliceTransfer_v2 } } + template + __device__ void RunPrint(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const DstDesc&, + const DstSliceOriginIdx&, + DstBuffer& dst_buf) + { + static_assert(DstDesc::IsKnownAtCompileTime(), + "wrong! DstDesc need to known at compile-time"); + + static_assert(is_known_at_compile_time>::value, + "wrong! DstSliceOrigin need to known at compile-time"); + + static_assert( + is_same, remove_cvref_t>::value && + "wrong! inconsistent type"); + + // DstDesc and dst_slice_origin_idx are known at compile-time + constexpr auto dst_desc = remove_cvref_t{}; + constexpr auto dst_slice_origin_idx = DstSliceOriginIdx{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_scalar_step_in_vector = + generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + // loop over tensor and copy + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + static_for<0, num_access, 1>{}([&](auto idx_1d) { + typename vector_type_maker::type src_vector; + + using src_vector_t = + typename vector_type_maker::type::type; + constexpr auto src_data_idx = SpaceFillingCurve::GetIndex(idx_1d); + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + + printf("Tid: %03d, Ascale read gmem src_data_coord.GetOffset() = %d\n", + get_thread_local_1d_id(), + src_coord_.GetOffset()); + // copy data from src_buf into src_vector + src_vector.template AsType()(Number<0>{}) = + src_buf.template Get(src_coord_.GetOffset() / PackedSize, + is_src_valid); + + // copy data from src_vector into dst_buf + static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) { + constexpr index_t dst_offset = + dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx + + i * src_scalar_step_in_vector); + + if constexpr(InvalidElementAsNaN) + { + dst_buf(Number{}) = + is_src_valid + ? type_convert(src_vector.template AsType()[i]) + : NumericLimits::QuietNaN(); + } + else + { + dst_buf(Number{}) = + type_convert(src_vector.template AsType()[i]); + } + }); + + if constexpr(idx_1d.value != num_access - 1) + { + constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d); + + move_tensor_coordinate( + src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step)); + } + }); + + // move src coordinate back to slice origin (or not) + if constexpr(SrcResetCoordinateAfterRun) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep()); + + move_tensor_coordinate(src_desc, src_coord_, src_reset_step); + } + } + __device__ static constexpr auto GetSrcCoordinateResetStep() { constexpr auto src_scalar_per_access = generate_sequence( @@ -1164,6 +1257,8 @@ struct ThreadwiseTensorSliceTransfer_v4 // copy data from src_buf into src_tmp_vector if constexpr(SrcBuffer::IsDynamicBuffer()) { + // printf("Tid: %03d, read lds src_data_coord.GetOffset() = %d\n", + // get_thread_local_1d_id(),src_data_coord.GetOffset()); src_tmp_vector.template AsType()(Number<0>{}) = src_buf.template Get(src_data_coord.GetOffset() / PackedSize, is_src_valid); diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp index 7ccea96dda..c95894f960 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp @@ -627,6 +627,37 @@ struct ThreadwiseTensorSliceTransfer_v3r1 dst_element_op_(dst_v, dst_vector_container.template AsType()[i]); }); + // #if defined(__gfx950__) + // printf("Tid: %03d, a_gmem: %02x %02x %02x %02x %02x %02x + // %02x %02x|\n", + // get_thread_local_1d_id(), + // *(reinterpret_cast(&(dst_vector_container.template + // AsType()[Number<0>{}]))), + // *(reinterpret_cast(&(dst_vector_container.template + // AsType()[Number<1>{}]))), + // *(reinterpret_cast(&(dst_vector_container.template + // AsType()[Number<2>{}]))), + // *(reinterpret_cast(&(dst_vector_container.template + // AsType()[Number<3>{}]))), + // *(reinterpret_cast(&(dst_vector_container.template + // AsType()[Number<0 + 4>{}]))), + // *(reinterpret_cast(&(dst_vector_container.template + // AsType()[Number<1 + 4>{}]))), + // *(reinterpret_cast(&(dst_vector_container.template + // AsType()[Number<2 + 4>{}]))), + // *(reinterpret_cast(&(dst_vector_container.template + // AsType()[Number<3 + 4>{}])))); + // #endif + // printf("Tid: %03d, write to dst_coord_.GetOffset(): %d\n", + // get_thread_local_1d_id(), dst_coord_.GetOffset() / PackedSize); // copy data from dst_vector_container to dst_buf dst_buf.template Set( dst_coord_.GetOffset() / PackedSize, diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_blockscale_wp.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_blockscale_wp.hpp index 1a75db60e4..d9ff81a863 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_blockscale_wp.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_blockscale_wp.hpp @@ -152,13 +152,13 @@ struct DeviceOperationInstanceFactory< { add_device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances( op_ptrs); - add_device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_kpadding_instances( - op_ptrs); + // add_device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_kpadding_instances( + // op_ptrs); add_device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_default_instances( op_ptrs); - add_device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpadding_instances( - op_ptrs); + // add_device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpadding_instances( + // op_ptrs); } } #endif diff --git a/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/CMakeLists.txt index f13ab883a1..57cbd725aa 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/CMakeLists.txt @@ -8,9 +8,9 @@ list(APPEND GEMM_BLOCKSCALE_WP_INSTANCES device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp ) -set_source_files_properties(device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1;-mllvm;--slp-threshold=-32;-mllvm;--misched-bottomup=1") +set_source_files_properties(device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1;-mllvm;--slp-threshold=-32;-mllvm;--misched-bottomup=1") +set_source_files_properties(device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1;-mllvm;--slp-threshold=-32;-mllvm;--misched-bottomup=1") +set_source_files_properties(device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1;-mllvm;--slp-threshold=-32;-mllvm;--misched-bottomup=1") add_instance_library(device_gemm_blockscale_wp_instance ${GEMM_BLOCKSCALE_WP_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp b/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp index a0c95cf2ab..55726bb915 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp @@ -42,8 +42,7 @@ using device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances //################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // Compute friendly - DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 128, 16, 16, 32, 32, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 128, 16, 16, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 128, 16, 16, 16, 16, 8, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> // clang-format on @@ -52,32 +51,36 @@ using device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances template using device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_instances = std::tuple< // clang-format off - //################################| ALayout| BLayout| DsLayout| ELayout|AData | BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| - //################################| | | | | Type | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| - //################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| - //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + //#######################################################| ALayout| BLayout| DsLayout| ELayout|AData | BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#######################################################| | | | | Type | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#######################################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#######################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // Memory friendly - DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 256, 128, 8, 16, 16, 16, 1, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, - DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 128, 128, 8, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, - DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 64, 128, 8, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, - - DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 128, 256, 16, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, - DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 64, 256, 16, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, - - DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 256, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, - DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 128, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, - DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 64, 128, 16, 16, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, - - DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, - DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 64, 256, 16, 16, 16, 16, 2, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, - - DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 256, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, - DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, - DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, - - DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 256, 16, 16, 32, 32, 2, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, - DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8> + // Memory friendly + // 16x + DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 256, 128, 8, 16, 16, 16, 1, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 128, 128, 8, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 64, 128, 8, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 128, 256, 16, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 64, 256, 16, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + //32x + DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 256, 128, 16, 16, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 128, 128, 16, 16, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 64, 128, 16, 16, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 128, 256, 16, 16, 16, 16, 2, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 64, 256, 16, 16, 16, 16, 2, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + //48x + DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 48, 256, 128, 8, 16, 16, 16, 3, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 48, 128, 128, 8, 16, 16, 16, 3, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 48, 64, 128, 8, 16, 16, 16, 3, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 48, 128, 256, 16, 16, 16, 16, 3, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 48, 64, 256, 16, 16, 16, 16, 3, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + //64x + DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 256, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 256, 16, 16, 16, 16, 4, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 256, 16, 16, 16, 16, 4, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8> // clang-format on >; } // namespace instance diff --git a/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp b/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp index e3844b1ef7..1e0e61ecaa 100644 --- a/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp +++ b/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp @@ -99,9 +99,7 @@ bool profile_gemm_blockscale_weighpreshuffle_impl(int do_verification, } }; - ck::index_t Scale_Stride_AM = ck::is_same_v - ? ((K + ScaleBlockK - 1) / ScaleBlockK) - : ((M + ScaleBlockM - 1) / ScaleBlockM); + ck::index_t Scale_Stride_AM = ((M + ScaleBlockM - 1) / ScaleBlockM); ck::index_t Scale_Stride_BN = ck::is_same_v ? ((K + ScaleBlockK - 1) / ScaleBlockK) : ((N + ScaleBlockN - 1) / ScaleBlockN); @@ -110,7 +108,7 @@ bool profile_gemm_blockscale_weighpreshuffle_impl(int do_verification, Tensor a1_m_k(f_host_tensor_descriptor((M + ScaleBlockM - 1) / ScaleBlockM, (K + ScaleBlockK - 1) / ScaleBlockK, Scale_Stride_AM, - ALayout{})); + ck::tensor_layout::gemm::ColumnMajor{})); Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor b_preshuffled_mfma16( f_host_tensor_descriptor(K, N, StrideB, BLayout{})); // use layout only for size @@ -303,16 +301,25 @@ bool profile_gemm_blockscale_weighpreshuffle_impl(int do_verification, if constexpr(is_same_v || is_same_v || is_same_v) { - std::string msg = "Error: Incorrect results!"; - double rtol = 5e-2; - double atol = 5e-2; - pass = pass & ck::utils::check_err( - e_m_n_device_result, e_m_n_host_result, msg, rtol, atol); + std::string msg = "Error: Incorrect results!"; + double rtol = 5e-2; + double atol = 5e-2; + bool current_pass = ck::utils::check_err( + e_m_n_device_result, e_m_n_host_result, msg, rtol, atol); + pass = pass & current_pass; + if(!current_pass) + { + std::cout << op_ptr->GetTypeString() << " failed" << std::endl; + } } else { #endif pass = pass & ck::utils::check_err(e_m_n_device_result, e_m_n_host_result); + if(!pass) + { + std::cout << op_ptr->GetTypeString() << " failed" << std::endl; + } #if defined CK_ENABLE_FP8 } #endif diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index f3fd3b8d2e..f02811947a 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -1,90 +1,90 @@ # ckProfiler set(PROFILER_SOURCES profiler.cpp - profile_gemm.cpp - profile_reduce.cpp - profile_groupnorm_bwd_data.cpp - profile_groupnorm_fwd.cpp - profile_layernorm_bwd_data.cpp - profile_layernorm_bwd_gamma_beta.cpp - profile_groupnorm_bwd_gamma_beta.cpp - profile_layernorm_fwd.cpp - profile_max_pool2d_fwd.cpp - profile_pool3d_fwd.cpp - profile_avg_pool3d_bwd.cpp - profile_max_pool3d_bwd.cpp - profile_avg_pool2d_bwd.cpp - profile_max_pool2d_bwd.cpp - profile_softmax.cpp - profile_batchnorm_fwd.cpp - profile_batchnorm_bwd.cpp - profile_batchnorm_infer.cpp - profile_conv_tensor_rearrange.cpp - profile_transpose.cpp - profile_permute_scale.cpp + # profile_gemm.cpp + # profile_reduce.cpp + # profile_groupnorm_bwd_data.cpp + # profile_groupnorm_fwd.cpp + # profile_layernorm_bwd_data.cpp + # profile_layernorm_bwd_gamma_beta.cpp + # profile_groupnorm_bwd_gamma_beta.cpp + # profile_layernorm_fwd.cpp + # profile_max_pool2d_fwd.cpp + # profile_pool3d_fwd.cpp + # profile_avg_pool3d_bwd.cpp + # profile_max_pool3d_bwd.cpp + # profile_avg_pool2d_bwd.cpp + # profile_max_pool2d_bwd.cpp + # profile_softmax.cpp + # profile_batchnorm_fwd.cpp + # profile_batchnorm_bwd.cpp + # profile_batchnorm_infer.cpp + # profile_conv_tensor_rearrange.cpp + # profile_transpose.cpp + # profile_permute_scale.cpp ) if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") - if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) - list(APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp) - list(APPEND PROFILER_SOURCES profile_contraction_scale.cpp) - endif() - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - list(APPEND PROFILER_SOURCES profile_gemm_reduce.cpp) - list(APPEND PROFILER_SOURCES profile_batched_gemm_gemm.cpp) - list(APPEND PROFILER_SOURCES profile_batched_gemm_add_relu_gemm_add.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add_add_fastgelu.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add_fastgelu.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_gemm.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_streamk.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_fastgelu.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add_relu.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add_silu.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_gemm_fixed_nk.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_gemm_tile_loop.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_gemm_multiply_tile_loop.cpp) - endif() - list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp) - if(SUPPORTED_GPU_TARGETS MATCHES "gfx94") - list(APPEND PROFILER_SOURCES profile_gemm_multiply_multiply.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_multiply_multiply_wp.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_ab_scale.cpp) + # if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) + # list(APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp) + # list(APPEND PROFILER_SOURCES profile_contraction_scale.cpp) + # endif() + # if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + # list(APPEND PROFILER_SOURCES profile_gemm_reduce.cpp) + # list(APPEND PROFILER_SOURCES profile_batched_gemm_gemm.cpp) + # list(APPEND PROFILER_SOURCES profile_batched_gemm_add_relu_gemm_add.cpp) + # list(APPEND PROFILER_SOURCES profile_gemm_add.cpp) + # list(APPEND PROFILER_SOURCES profile_gemm_add_add_fastgelu.cpp) + # list(APPEND PROFILER_SOURCES profile_gemm_add_fastgelu.cpp) + # list(APPEND PROFILER_SOURCES profile_grouped_gemm.cpp) + # list(APPEND PROFILER_SOURCES profile_gemm_streamk.cpp) + # list(APPEND PROFILER_SOURCES profile_gemm_fastgelu.cpp) + # list(APPEND PROFILER_SOURCES profile_gemm_add_relu.cpp) + # list(APPEND PROFILER_SOURCES profile_gemm_add_silu.cpp) + # list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp) + # list(APPEND PROFILER_SOURCES profile_grouped_gemm_fixed_nk.cpp) + # list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp) + # list(APPEND PROFILER_SOURCES profile_grouped_gemm_tile_loop.cpp) + # list(APPEND PROFILER_SOURCES profile_grouped_gemm_multiply_tile_loop.cpp) + # endif() + # list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp) + # if(SUPPORTED_GPU_TARGETS MATCHES "gfx94") + # list(APPEND PROFILER_SOURCES profile_gemm_multiply_multiply.cpp) + # list(APPEND PROFILER_SOURCES profile_gemm_multiply_multiply_wp.cpp) + # list(APPEND PROFILER_SOURCES profile_gemm_ab_scale.cpp) list(APPEND PROFILER_SOURCES profile_gemm_blockscale_wp.cpp) - endif() - list(APPEND PROFILER_SOURCES profile_batched_gemm.cpp) - list(APPEND PROFILER_SOURCES profile_batched_gemm_reduce.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add_multiply.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_bias_add_reduce.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_splitk.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_universal.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_b_scale.cpp) - list(APPEND PROFILER_SOURCES profile_batched_gemm_b_scale.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_universal_batched.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_universal_reduce.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_universal_streamk.cpp) - list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu.cpp) - list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu_add.cpp) - list(APPEND PROFILER_SOURCES profile_conv_bwd_data.cpp) - list(APPEND PROFILER_SOURCES profile_conv_fwd.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_conv_fwd_outelementop.cpp) + # endif() + # list(APPEND PROFILER_SOURCES profile_batched_gemm.cpp) + # list(APPEND PROFILER_SOURCES profile_batched_gemm_reduce.cpp) + # list(APPEND PROFILER_SOURCES profile_gemm_add_multiply.cpp) + # list(APPEND PROFILER_SOURCES profile_gemm_bias_add_reduce.cpp) + # list(APPEND PROFILER_SOURCES profile_gemm_splitk.cpp) + # list(APPEND PROFILER_SOURCES profile_gemm_universal.cpp) + # list(APPEND PROFILER_SOURCES profile_gemm_b_scale.cpp) + # list(APPEND PROFILER_SOURCES profile_batched_gemm_b_scale.cpp) + # list(APPEND PROFILER_SOURCES profile_gemm_universal_batched.cpp) + # list(APPEND PROFILER_SOURCES profile_gemm_universal_reduce.cpp) + # list(APPEND PROFILER_SOURCES profile_gemm_universal_streamk.cpp) + # list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu.cpp) + # list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu_add.cpp) + # list(APPEND PROFILER_SOURCES profile_conv_bwd_data.cpp) + # list(APPEND PROFILER_SOURCES profile_conv_fwd.cpp) + # list(APPEND PROFILER_SOURCES profile_grouped_conv_fwd_outelementop.cpp) endif() -if(SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12" OR SUPPORTED_GPU_TARGETS MATCHES "gfx9") - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp) - endif() - list(APPEND PROFILER_SOURCES profile_grouped_conv_fwd.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_data.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_weight.cpp) -endif() +# if(SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12" OR SUPPORTED_GPU_TARGETS MATCHES "gfx9") +# if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) +# list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp) +# endif() +# list(APPEND PROFILER_SOURCES profile_grouped_conv_fwd.cpp) +# list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_data.cpp) +# list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_weight.cpp) +# endif() -if(DL_KERNELS) - list(APPEND PROFILER_SOURCES profile_batched_gemm_multi_d.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_weight.cpp) -endif() +# if(DL_KERNELS) +# list(APPEND PROFILER_SOURCES profile_batched_gemm_multi_d.cpp) +# list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_weight.cpp) +# endif() set(PROFILER_EXECUTABLE ckProfiler) @@ -97,91 +97,91 @@ if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 600241132) endif() target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE utility getopt::getopt) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_fwd_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_data_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_gamma_beta_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool2d_fwd_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool3d_fwd_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool2d_bwd_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool3d_bwd_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_max_pool_bwd_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_image_to_column_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_transpose_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_permute_scale_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_fwd_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_data_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_gamma_beta_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool2d_fwd_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool3d_fwd_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool2d_bwd_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool3d_bwd_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_max_pool_bwd_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_image_to_column_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_transpose_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_permute_scale_instance) if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") - if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance) - endif() - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_add_fastgelu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_fastgelu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_gemm_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_add_relu_gemm_add_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_streamk_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_silu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fixed_nk_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_tile_loop_instance) - endif() - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance) - if(SUPPORTED_GPU_TARGETS MATCHES "gfx94") - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_multiply_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_multiply_wp_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_ab_scale_instance) + # if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) + # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance) + # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance) + # endif() + # if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_instance) + # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_add_fastgelu_instance) + # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_fastgelu_instance) + # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_gemm_instance) + # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_add_relu_gemm_add_instance) + # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_instance) + # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_streamk_instance) + # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance) + # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_instance) + # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_silu_instance) + # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance) + # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fixed_nk_instance) + # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance) + # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_tile_loop_instance) + # endif() + # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance) + # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance) + # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance) + # if(SUPPORTED_GPU_TARGETS MATCHES "gfx94") + # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_multiply_instance) + # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_multiply_wp_instance) + # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_ab_scale_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_blockscale_wp_instance) - endif() - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_b_scale_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_b_scale_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_batched_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_reduce_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_streamk_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_reduce_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bias_add_reduce_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_add_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_fwd_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv1d_bwd_data_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv3d_bwd_data_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_bwd_data_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_convscale_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_convinvscale_instance) + # endif() + # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance) + # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_instance) + # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_b_scale_instance) + # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_b_scale_instance) + # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_batched_instance) + # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_reduce_instance) + # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_streamk_instance) + # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance) + # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_reduce_instance) + # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bias_add_reduce_instance) + # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_instance) + # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_instance) + # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_add_instance) + # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_fwd_instance) + # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv1d_bwd_data_instance) + # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv3d_bwd_data_instance) + # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_bwd_data_instance) + # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance) + # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance) + # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_convscale_instance) + # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_convinvscale_instance) endif() -if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12") - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance) - endif() - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_fwd_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance) -endif() +# if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12") +# if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance) +# endif() +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_fwd_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance) +# endif() -if(DL_KERNELS) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_multi_d_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance) -endif() +# if(DL_KERNELS) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_multi_d_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance) +# endif() rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler) diff --git a/profiler/src/profile_gemm_blockscale_wp.cpp b/profiler/src/profile_gemm_blockscale_wp.cpp index 01df933f7d..d413659f71 100644 --- a/profiler/src/profile_gemm_blockscale_wp.cpp +++ b/profiler/src/profile_gemm_blockscale_wp.cpp @@ -35,7 +35,7 @@ enum struct ScaleBlockTile Tile_1_128_128, // 1 }; -#define OP_NAME "gemm_blockscale_weighpreshuffle" +#define OP_NAME "gemm_blockscale_wp" #define OP_DESC "GEMM_BlockScale_WeightPreshuffle" int profile_gemm_blockscale_weighpreshuffle(int argc, char* argv[])