diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index 2d65368d4f..aba462638e 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -149,3 +149,7 @@ add_example_executable(example_gemm_wmma_fp16_fp8_v3 gemm_wmma_fp16_fp8_v3.cpp) add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_fp8_v3) add_example_executable(example_gemm_wmma_fp16_pk_i4_v3_b_scale gemm_wmma_fp16_pk_i4_v3_b_scale.cpp) add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_pk_i4_v3_b_scale) +add_example_executable(example_gemm_wmma_fp8_bpreshuffle gemm_wmma_fp8_bpreshuffle.cpp) +add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp8_bpreshuffle) +add_example_executable(example_gemm_wmma_fp16_bpreshuffle gemm_wmma_fp16_bpreshuffle.cpp) +add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_bpreshuffle) diff --git a/example/01_gemm/gemm_wmma_fp16_bpreshuffle.cpp b/example/01_gemm/gemm_wmma_fp16_bpreshuffle.cpp new file mode 100644 index 0000000000..d03971e6ec --- /dev/null +++ b/example/01_gemm/gemm_wmma_fp16_bpreshuffle.cpp @@ -0,0 +1,70 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/stream_config.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_preshuffle.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/utility/get_id.hpp" +#include "ck/utility/scheduler_enum.hpp" + +#include +#include +#include + +using F16 = ck::half_t; +using F32 = float; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using CDataType = F16; +using ComputeTypeA = F16; +using ComputeTypeB = F16; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr bool PermuteA = false; +static constexpr bool PermuteB = false; +static constexpr int KPack = 8; // int4 -> 32, fp8 -> 16, fp16 -> 8 +// clang-format off +using DeviceOpInstance = + ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3_BPreshuffle< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CElementOp, GemmDefault, + 128, + 32, 128, 128, + 8, 8, + 16, 16, + 2, 2, + S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB>; +// clang-format on + +#include "run_gemm_wmma_bpreshuffle_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/01_gemm/gemm_wmma_fp8_bpreshuffle.cpp b/example/01_gemm/gemm_wmma_fp8_bpreshuffle.cpp new file mode 100644 index 0000000000..8f8b380b93 --- /dev/null +++ b/example/01_gemm/gemm_wmma_fp8_bpreshuffle.cpp @@ -0,0 +1,72 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/stream_config.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_preshuffle.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/amd_ck_fp8.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/utility/get_id.hpp" +#include "ck/utility/scheduler_enum.hpp" + +#include +#include +#include + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using F32 = float; + +using ADataType = F8; +using BDataType = F8; +using AccDataType = F32; +using CShuffleDataType = F32; +using CDataType = F16; +using ComputeTypeA = F8; +using ComputeTypeB = F8; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr bool PermuteA = false; +static constexpr bool PermuteB = false; +static constexpr int KPack = 16; // int4 -> 32, fp8 -> 16, fp16 -> 8 +// clang-format off +using DeviceOpInstance = + ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3_BPreshuffle< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CElementOp, GemmDefault, + 256, + 32, 128, 256, + 16, 16, + 16, 16, + 2, 1, + S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 16, 16, 0, + S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 16, 16, 0, + 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB>; +// clang-format on + +#include "run_gemm_wmma_bpreshuffle_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/01_gemm/run_gemm_wmma_bpreshuffle_example.inc b/example/01_gemm/run_gemm_wmma_bpreshuffle_example.inc new file mode 100644 index 0000000000..b1d73cfe10 --- /dev/null +++ b/example/01_gemm/run_gemm_wmma_bpreshuffle_example.inc @@ -0,0 +1,206 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +template +bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + auto KBatch = problem_size.KBatch; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor b_k_n_preshuffled(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + switch(config.init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "b_k_n_preshuffled: " << b_k_n_preshuffled.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + // do GEMM + auto device_op = DeviceOpInstance{}; + + // weight pre-shuffle + int NPerWmma = device_op.GetPreShuffleParameters(); + int KLane = ck::get_warp_size() / NPerWmma; + + int K0 = K / (KLane * KPack); + // K -> K0 KLane KPack + // N -> N0 NPerWmma + // N, K -> N0 K0 KLane NPerWmma KPack + int tempk; + for(int n = 0; n < N; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / NPerWmma; + int n1 = n % NPerWmma; + + int k0 = k / (KLane * KPack); + tempk = k % (KLane * KPack); + int k1 = tempk / KPack; + int k2 = tempk % KPack; + + int outputIndex = n0 * KPack * NPerWmma * KLane * K0 + k0 * KPack * NPerWmma * KLane + + k1 * KPack * NPerWmma + n1 * KPack + k2; + + b_k_n_preshuffled(outputIndex) = b_k_n(n * K + k); + } + } + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n_preshuffled.mData.data()); + c_m_n_device_buf.ToDevice(c_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + auto invoker = device_op.MakeInvoker(); + + auto argument = + device_op.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + std::cerr << device_op.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + float ave_time = + invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 50, 50, false, 1}); + + bool pass = true; + if(config.do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + invoker.Run(argument, StreamConfig{nullptr, false, 0}); + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); + } + + if(config.time_kernel) + { + ave_time = + invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50}); + + std::size_t flop = 2_uz * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << device_op.GetTypeString() << std::endl; + } + + return pass; +} + +bool run_gemm_splitk_example(int argc, char* argv[]) +{ + ProblemSizeSplitK problem_size{3840, 4096, 4096, 4096, 4096, 4096, 1}; + ExecutionConfig config; + + return parse_cmd_args(argc, argv, problem_size, config) && run_gemm(problem_size, config); +} diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_preshuffle.hpp new file mode 100644 index 0000000000..87bca24448 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_preshuffle.hpp @@ -0,0 +1,303 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/utility/scheduler_enum.hpp" +#include "ck/utility/tuple.hpp" + +#include +#include +#include +#include +#include + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemm_Wmma_CShuffleV3_BPreshuffle + : public DeviceGemmV2BPreshuffle +{ + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< + ALayout, + BLayout, + Tuple<>, + CLayout, + Tuple, + Tuple, + AccDataType, + CShuffleDataType, + Tuple<>, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB, + true>; + + using Argument = typename GridwiseGemm::Argument; + + int GetPreShuffleParameters() override { return NPerWmma; } + + using DeviceGemmCommon = DeviceGemm_Wmma_CShuffleV3_Common, + Tuple, + Tuple<>, + CDataType, + MPerBlock, + NPerBlock, + KPerBlock, + BlockSize, + AK1, + BK1, + GemmSpec, + CShuffleBlockTransferScalarPerVector, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + true>; + + // Invoker + using Invoker = typename DeviceGemmCommon::Invoker; + + static bool IsSupportedArgument(const Argument& arg) + { + if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0) + { + return false; + } + return DeviceGemmCommon::IsSupportedArgument(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + index_t GetKPerBlock() override { return KPerBlock; } + + bool GetPermuteA() override { return PermuteA; } + bool GetPermuteB() override { return PermuteB; } + + static auto MakeArgument(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{std::array{p_a}, + std::array{p_b}, + std::array{}, + static_cast(p_c), + M, + N, + K, + std::array{StrideA}, + std::array{StrideB}, + std::array{}, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(std::array{p_a}, + std::array{p_b}, + std::array{}, + static_cast(p_c), + M, + N, + K, + std::array{StrideA}, + std::array{StrideB}, + std::array{}, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGemm_BPreshuffle_Wmma_CShuffleV3" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", " + << "WaveTile: " + << MPerWmma << "x" << NPerWmma << ", " + << "WaveMap: " + << MRepeat << "x" << NRepeat << ", " + << "VmemReadVec: " + << ABlockTransferSrcScalarPerVector << "x" << BBlockTransferSrcScalarPerVector << ", " + << "BlkGemmPipelineScheduler: " + << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", " + << "BlkGemmPipelineVersion: " + << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", " + << "BlkGemmPipelinePrefetchStages: " + << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", " + << "Kpack: " + << GridwiseGemm::KPack; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_preshuffle.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_preshuffle.hpp index d8d1776a44..1a5709854c 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_preshuffle.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_preshuffle.hpp @@ -3,18 +3,19 @@ #pragma once -#include -#include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#ifdef CK_USE_XDL +#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/utility/amd_ck_fp8.hpp" +#include "ck/utility/data_type.hpp" +#if defined(CK_USE_XDL) || defined(CK_USE_WMMA) #include "gemm_universal_preshuffle.inc" #endif +#include +#include + namespace ck { namespace tensor_operation { namespace device { @@ -51,7 +52,7 @@ struct DeviceOperationInstanceFactory< static auto GetInstances() { -#ifdef CK_USE_XDL +#if defined(CK_USE_XDL) || defined(CK_USE_WMMA) std::vector> op_ptrs; #if(defined(CK_ENABLE_BF16) && defined(CK_ENABLE_FP8)) if constexpr(is_same_v && is_same_v && @@ -60,6 +61,7 @@ struct DeviceOperationInstanceFactory< if constexpr(is_same_v && is_same_v && is_same_v) { +#ifdef CK_USE_XDL add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma32x32_mn_instances( op_ptrs); add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma32x32_mn_compute_instances( @@ -90,6 +92,17 @@ struct DeviceOperationInstanceFactory< op_ptrs); add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instances_part1( op_ptrs); +#endif +#ifdef CK_USE_WMMA + add_device_gemm_universal_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_default_instances_p1( + op_ptrs); + add_device_gemm_universal_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_default_instances_p2( + op_ptrs); + add_device_gemm_universal_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_default_instances_p3( + op_ptrs); + add_device_gemm_universal_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_default_instances_p4( + op_ptrs); +#endif } } #endif @@ -100,6 +113,7 @@ struct DeviceOperationInstanceFactory< if constexpr(is_same_v && is_same_v && is_same_v) { +#ifdef CK_USE_XDL add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances_p1( op_ptrs); add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances_p2( @@ -136,10 +150,21 @@ struct DeviceOperationInstanceFactory< op_ptrs); add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p6( op_ptrs); +#endif +#ifdef CK_USE_WMMA + add_device_gemm_universal_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_default_instances_p1( + op_ptrs); + add_device_gemm_universal_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_default_instances_p2( + op_ptrs); + add_device_gemm_universal_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_default_instances_p3( + op_ptrs); + add_device_gemm_universal_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_default_instances_p4( + op_ptrs); +#endif } } #endif -#endif // CK_USE_XDL +#endif // CK_USE_XDL || CK_USE_WMMA return op_ptrs; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_preshuffle.inc b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_preshuffle.inc index b983913953..4f61958f34 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_preshuffle.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_preshuffle.inc @@ -13,8 +13,7 @@ namespace instance { using GemmF8F8BF16InstanceVector = std::vector>>&; -using GemmF8F8F16InstanceVector = std::vector>>&; +#ifdef CK_USE_XDL void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma32x32_mn_instances( GemmF8F8BF16InstanceVector& instances); @@ -61,7 +60,32 @@ void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp GemmF8F8BF16InstanceVector& instances); #endif + +#ifdef CK_USE_WMMA + +void add_device_gemm_universal_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_default_instances_p1( + GemmF8F8BF16InstanceVector& instances); + +void add_device_gemm_universal_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_default_instances_p2( + GemmF8F8BF16InstanceVector& instances); + +void add_device_gemm_universal_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_default_instances_p3( + GemmF8F8BF16InstanceVector& instances); + +void add_device_gemm_universal_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_default_instances_p4( + GemmF8F8BF16InstanceVector& instances); + +#endif + +#endif + #if(defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8)) + +using GemmF8F8F16InstanceVector = std::vector>>&; + +#ifdef CK_USE_XDL + void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances_p1( GemmF8F8F16InstanceVector& instances); void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances_p2( @@ -99,6 +123,25 @@ void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_ GemmF8F8F16InstanceVector& instances); void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p6( GemmF8F8F16InstanceVector& instances); + +#endif + +#ifdef CK_USE_WMMA + +void add_device_gemm_universal_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_default_instances_p1( + GemmF8F8F16InstanceVector& instances); + +void add_device_gemm_universal_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_default_instances_p2( + GemmF8F8F16InstanceVector& instances); + +void add_device_gemm_universal_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_default_instances_p3( + GemmF8F8F16InstanceVector& instances); + +void add_device_gemm_universal_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_default_instances_p4( + GemmF8F8F16InstanceVector& instances); + +#endif + #endif } // namespace instance } // namespace device diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/CMakeLists.txt index a022b746ac..c8fc544c83 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/CMakeLists.txt @@ -1,7 +1,7 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS set(GEMM_UNIVERSAL_INSTANCES) # F8_F8_BF16 @@ -21,6 +21,10 @@ device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshu device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p5_default_instance.cpp device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_f8_bf16_mk_mfma32x32_mn_default_instance.cpp device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_f8_bf16_mk_mfma32x32_mn_comp_instance.cpp +device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn_default_instance_p1.cpp +device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn_default_instance_p2.cpp +device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn_default_instance_p3.cpp +device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn_default_instance_p4.cpp ) # F8_F8_F16 @@ -43,6 +47,10 @@ list(APPEND GEMM_UNIVERSAL_INSTANCES device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p4.cpp device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p5.cpp device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p6.cpp + device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn_default_instance_p1.cpp + device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn_default_instance_p2.cpp + device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn_default_instance_p3.cpp + device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn_default_instance_p4.cpp ) # F8_F8_F16 @@ -64,6 +72,10 @@ set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_f16/devic set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p4.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p5.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p6.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn_default_instance_p1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn_default_instance_p2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn_default_instance_p3.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn_default_instance_p4.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") # F8_F8_BF16 set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_f8_bf16_mk_mfma32x32_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") @@ -81,5 +93,9 @@ set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_bf16/devi set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p5_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_nk_mn_comp_default_instance_p1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_nk_mn_comp_default_instance_p2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn_default_instance_p1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn_default_instance_p2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn_default_instance_p3.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn_default_instance_p4.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") add_instance_library(device_gemm_universal_preshuffle_instance ${GEMM_UNIVERSAL_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn.hpp new file mode 100644 index 0000000000..dd56980f0a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn.hpp @@ -0,0 +1,106 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_preshuffle.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/utility/amd_ck_fp8.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/utility/scheduler_enum.hpp" +#include "ck/utility/sequence.hpp" + +#include + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = f8_t; +using BF16 = bhalf_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +static constexpr auto v1 = BlockGemmPipelineVersion::v1; + +template +using device_gemm_universal_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_instances_p1 = std::tuple< + // clang-format off + //#####################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemmPipeSched| BlkGemmPipelineVer| ComputeTypeA| + //#####################################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| | | | + //#####################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 128, 128, 16, 16, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 8, 1, 32>, S<4, 4, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 128, 16, 16, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 256, 128, 16, 16, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 128, 16, 16, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 16, 16, 16, 16, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 > + // clang-format on + >; + +template +using device_gemm_universal_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_instances_p2 = std::tuple< + // clang-format off + //#####################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemmPipeSched| BlkGemmPipelineVer| ComputeTypeA| + //#####################################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| | | | + //#####################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 128, 16, 16, 16, 16, 8, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 128, 16, 16, 16, 16, 8, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 128, 16, 16, 16, 16, 16, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 64, 128, 16, 16, 16, 16, 16, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 128, 16, 16, 16, 16, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, Intrawave, v1, F8 > + // clang-format on + >; + +template +using device_gemm_universal_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_instances_p3 = std::tuple< + // clang-format off + //#####################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemmPipeSched| BlkGemmPipelineVer| ComputeTypeA| + //#####################################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| | | | + //#####################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 128, 256, 16, 16, 16, 16, 2, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 256, 16, 16, 16, 16, 4, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 128, 256, 16, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 256, 16, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 512, 256, 16, 16, 16, 16, 1, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 256, 16, 16, 16, 16, 8, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 256, 16, 16, 16, 16, 8, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, Intrawave, v1, F8 > + // clang-format on + >; + +template +using device_gemm_universal_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_instances_p4 = std::tuple< + // clang-format off + //#####################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemmPipeSched| BlkGemmPipelineVer| ComputeTypeA| + //#####################################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| | | | + //#####################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 7, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 128, 128, 16, 16, 16, 16, 7, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 64, 128, 16, 16, 16, 16, 7, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 128, 16, 16, 16, 16, 4, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 224, 128, 16, 16, 16, 16, 2, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 224, 128, 16, 16, 16, 16, 1, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 1>, Intrawave, v1, F8 > + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn_default_instance_p1.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn_default_instance_p1.cpp new file mode 100644 index 0000000000..e7e43db376 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn_default_instance_p1.cpp @@ -0,0 +1,33 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn.hpp" + +#include +#include + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_universal_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_default_instances_p1( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_universal_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_instances_p1{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn_default_instance_p2.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn_default_instance_p2.cpp new file mode 100644 index 0000000000..240548279c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn_default_instance_p2.cpp @@ -0,0 +1,33 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn.hpp" + +#include +#include + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_universal_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_default_instances_p2( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_universal_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_instances_p2{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn_default_instance_p3.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn_default_instance_p3.cpp new file mode 100644 index 0000000000..af936b3924 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn_default_instance_p3.cpp @@ -0,0 +1,33 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn.hpp" + +#include +#include + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_universal_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_default_instances_p3( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_universal_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_instances_p3{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn_default_instance_p4.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn_default_instance_p4.cpp new file mode 100644 index 0000000000..019f27e01a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_bf16/device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn_default_instance_p4.cpp @@ -0,0 +1,33 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "device_gemm_wmma_universal_preshuffle_f8_f8_bf16_mk_wmma_mn.hpp" + +#include +#include + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_universal_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_default_instances_p4( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_universal_preshuffle_wmma_f8_f8_bf16_mk_wmma_mn_instances_p4{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn.hpp new file mode 100644 index 0000000000..b2b823d3bd --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn.hpp @@ -0,0 +1,106 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_preshuffle.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/utility/amd_ck_fp8.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/utility/scheduler_enum.hpp" +#include "ck/utility/sequence.hpp" + +#include + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = f8_t; +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +static constexpr auto v1 = BlockGemmPipelineVersion::v1; + +template +using device_gemm_universal_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_instances_p1 = std::tuple< + // clang-format off + //#####################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemmPipeSched| BlkGemmPipelineVer| ComputeTypeA| + //#####################################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| | | | + //#####################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 128, 128, 16, 16, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 8, 1, 32>, S<4, 4, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 128, 16, 16, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 256, 128, 16, 16, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 128, 16, 16, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 16, 16, 16, 16, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 > + // clang-format on + >; + +template +using device_gemm_universal_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_instances_p2 = std::tuple< + // clang-format off + //#####################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemmPipeSched| BlkGemmPipelineVer| ComputeTypeA| + //#####################################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| | | | + //#####################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 128, 16, 16, 16, 16, 8, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 128, 16, 16, 16, 16, 8, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 128, 16, 16, 16, 16, 16, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 64, 128, 16, 16, 16, 16, 16, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 128, 16, 16, 16, 16, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, Intrawave, v1, F8 > + // clang-format on + >; + +template +using device_gemm_universal_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_instances_p3 = std::tuple< + // clang-format off + //#####################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemmPipeSched| BlkGemmPipelineVer| ComputeTypeA| + //#####################################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| | | | + //#####################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 128, 256, 16, 16, 16, 16, 2, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 256, 16, 16, 16, 16, 4, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 128, 256, 16, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 256, 16, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 512, 256, 16, 16, 16, 16, 1, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 256, 16, 16, 16, 16, 8, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 256, 16, 16, 16, 16, 8, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, Intrawave, v1, F8 > + // clang-format on + >; + +template +using device_gemm_universal_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_instances_p4 = std::tuple< + // clang-format off + //#####################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemmPipeSched| BlkGemmPipelineVer| ComputeTypeA| + //#####################################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| | | | + //#####################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 7, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 128, 128, 16, 16, 16, 16, 7, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 64, 128, 16, 16, 16, 16, 7, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 128, 16, 16, 16, 16, 4, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 224, 128, 16, 16, 16, 16, 2, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 1>, Intrawave, v1, F8 >, + DeviceGemm_Wmma_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 224, 128, 16, 16, 16, 16, 1, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 1>, Intrawave, v1, F8 > + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn_default_instance_p1.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn_default_instance_p1.cpp new file mode 100644 index 0000000000..c1dc5f263b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn_default_instance_p1.cpp @@ -0,0 +1,33 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn.hpp" + +#include +#include + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_universal_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_default_instances_p1( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_universal_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_instances_p1{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn_default_instance_p2.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn_default_instance_p2.cpp new file mode 100644 index 0000000000..148edd3035 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn_default_instance_p2.cpp @@ -0,0 +1,33 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn.hpp" + +#include +#include + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_universal_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_default_instances_p2( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_universal_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_instances_p2{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn_default_instance_p3.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn_default_instance_p3.cpp new file mode 100644 index 0000000000..d9918d967c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn_default_instance_p3.cpp @@ -0,0 +1,33 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn.hpp" + +#include +#include + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_universal_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_default_instances_p3( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_universal_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_instances_p3{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn_default_instance_p4.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn_default_instance_p4.cpp new file mode 100644 index 0000000000..4635cdaec0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_wmma_universal_preshuffle_f8_f8_f16/device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn_default_instance_p4.cpp @@ -0,0 +1,33 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "device_gemm_wmma_universal_preshuffle_f8_f8_f16_mk_wmma_mn.hpp" + +#include +#include + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_universal_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_default_instances_p4( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_universal_preshuffle_wmma_f8_f8_f16_mk_wmma_mn_instances_p4{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/test/gemm_universal_preshuffle/CMakeLists.txt b/test/gemm_universal_preshuffle/CMakeLists.txt index 1abc4391bb..fd13826a4c 100644 --- a/test/gemm_universal_preshuffle/CMakeLists.txt +++ b/test/gemm_universal_preshuffle/CMakeLists.txt @@ -2,8 +2,8 @@ # SPDX-License-Identifier: MIT if(GPU_TARGETS MATCHES "gfx9[45]|gfx12") - add_gtest_executable(test_gemm_universal_preshuffle_xdl_fp8 test_gemm_universal_preshuffle_xdl_fp8.cpp) + add_gtest_executable(test_gemm_universal_preshuffle_fp8 test_gemm_universal_preshuffle_fp8.cpp) if(result EQUAL 0) - target_link_libraries(test_gemm_universal_preshuffle_xdl_fp8 PRIVATE utility device_gemm_universal_preshuffle_instance) + target_link_libraries(test_gemm_universal_preshuffle_fp8 PRIVATE utility device_gemm_universal_preshuffle_instance) endif() endif() diff --git a/test/gemm_universal_preshuffle/test_gemm_universal_preshuffle_xdl_fp8.cpp b/test/gemm_universal_preshuffle/test_gemm_universal_preshuffle_fp8.cpp similarity index 100% rename from test/gemm_universal_preshuffle/test_gemm_universal_preshuffle_xdl_fp8.cpp rename to test/gemm_universal_preshuffle/test_gemm_universal_preshuffle_fp8.cpp