diff --git a/example/20_grouped_conv_bwd_weight/CMakeLists.txt b/example/20_grouped_conv_bwd_weight/CMakeLists.txt index ab8ed5c2bf..c83b384a01 100644 --- a/example/20_grouped_conv_bwd_weight/CMakeLists.txt +++ b/example/20_grouped_conv_bwd_weight/CMakeLists.txt @@ -11,6 +11,12 @@ foreach(gpu IN LISTS GPU_TARGETS) if(result EQUAL 0) add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16) endif() + if(GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR GPU_TARGETS MATCHES "gfx942") + add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp) + if(result EQUAL 0) + add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8) + endif() + endif() set(target 1) endif() endforeach() diff --git a/example/20_grouped_conv_bwd_weight/common.hpp b/example/20_grouped_conv_bwd_weight/common.hpp index 15727495f0..0e7eb00e4b 100644 --- a/example/20_grouped_conv_bwd_weight/common.hpp +++ b/example/20_grouped_conv_bwd_weight/common.hpp @@ -23,6 +23,12 @@ using BF16 = ck::bhalf_t; using F16 = ck::half_t; using F32 = float; +#ifdef CK_ENABLE_FP8 +using F8 = ck::f8_t; +#endif +#ifdef CK_ENABLE_BF8 +using BF8 = ck::bf8_t; +#endif template using S = ck::Sequence; diff --git a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp index cb6d68e04a..dccde28aa0 100644 --- a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp +++ b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp @@ -65,6 +65,15 @@ using DeviceConvBwdWeightInstance = ck::tensor_operation::device::DeviceGroupedC 5, // CThreadTransferSrcDstVectorDim 4>; // CThreadTransferDstScalarPerVector +template +using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight; + #include "run_grouped_conv_bwd_weight_example.inc" int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); } diff --git a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp index 31e277b5c7..467209ef28 100644 --- a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp +++ b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp @@ -67,6 +67,15 @@ using DeviceConvBwdWeightInstance = S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock 128 / (sizeof(WeiDataType) * CHAR_BIT)>; // CBlockTransferScalarPerVector_NWaveNPerXdl +template +using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight; + #include "run_grouped_conv_bwd_weight_example.inc" int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); } diff --git a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp index 69c831cc54..8b751b0d98 100644 --- a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp +++ b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp @@ -66,6 +66,15 @@ using DeviceConvBwdWeightInstance = S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock 128 / (sizeof(WeiDataType) * CHAR_BIT)>; // CBlockTransferScalarPerVector_NWaveNPerXdl +template +using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight; + #include "run_grouped_conv_bwd_weight_example.inc" int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); } diff --git a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp new file mode 100644 index 0000000000..c678867e74 --- /dev/null +++ b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" + +using InDataType = F16; +using WeiDataType = F16; +using OutDataType = F16; +using AccDataType = F32; +using ComputeTypeA = BF8; +using ComputeTypeB = F8; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = PassThrough; + +template +using DeviceConvBwdWeightInstance = + ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle< + NDimSpatial, + ck::tuple_element_t>, + ck::tuple_element_t>, + ck::tuple_element_t>, + InDataType, // InDataType + WeiDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 2, // NXdlPerWave + S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder + S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 1, // ABlockTransferSrcScalarPerVector + 1, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder + S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 1, // BBlockTransferSrcScalarPerVector + 1, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 2, // CBlockTransferScalarPerVector_NWaveNPerXdl + ComputeTypeA, // ComputeTypeA + ComputeTypeB>; // ComputeTypeB + +template +using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight; + +#include "run_grouped_conv_bwd_weight_example.inc" + +int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); } diff --git a/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc b/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc index 34bca7dc86..755d6addd8 100644 --- a/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc +++ b/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc @@ -1,15 +1,6 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. -template -using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight; - template bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, const ck::utils::conv::ConvParam& conv_param) @@ -46,8 +37,8 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, out.GenerateTensorValue(GeneratorTensor_2{-5, 5}); break; default: - in.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - out.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + in.GenerateTensorValue(GeneratorTensor_3{0.0, 0.2}); + out.GenerateTensorValue(GeneratorTensor_3{-0.1, 0.1}); } DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); @@ -113,18 +104,7 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, return true; } - float avg_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); - - std::size_t flop = conv_param.GetFlops(); - std::size_t num_btype = conv_param.GetByte(); - - float tflops = static_cast(flop) / 1.E9 / avg_time; - - float gb_per_sec = num_btype / 1.E6 / avg_time; - - std::cerr << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" - << std::endl - << "DeviceOp: " << conv.GetTypeString() << std::endl; + invoker.Run(argument, StreamConfig{nullptr, false}); if(config.do_verification) { @@ -148,6 +128,19 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, return ck::utils::check_err(wei_device_result.mData, wei_host_result.mData); } + float avg_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = conv_param.GetFlops(); + std::size_t num_btype = conv_param.GetByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + + float gb_per_sec = num_btype / 1.E6 / avg_time; + + std::cerr << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl + << "DeviceOp: " << conv.GetTypeString() << std::endl; + return true; } diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp index ab9e6adb41..7296e4faaa 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp @@ -20,7 +20,9 @@ template + typename OutElementwiseOperation, + typename ComputeTypeA = InDataType, + typename ComputeTypeB = ComputeTypeA> struct DeviceGroupedConvBwdWeight : public BaseOperator { virtual std::unique_ptr diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index 71c528e4c9..d4bba3bf64 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -48,7 +48,8 @@ struct ComputePtrOffsetOfStridedBatch } // namespace template (compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); - __shared__ FloatAB p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB)]; + __shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)]; GridwiseGemm::template Run(p_a_grid + a_batch_offset, p_b_grid + b_batch_offset, @@ -163,7 +164,9 @@ template + index_t CBlockTransferScalarPerVector_NWaveNPerXdl, + typename ComputeTypeA = InDataType, + typename ComputeTypeB = ComputeTypeA> struct DeviceGroupedConvBwdWeight_Xdl_CShuffle : public DeviceGroupedConvBwdWeight + OutElementwiseOperation, + ComputeTypeA, + ComputeTypeB> { using DeviceOp = DeviceGroupedConvBwdWeight_Xdl_CShuffle; @@ -1045,7 +1050,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< BlockSize, - ADataType, // TODO: distinguish A/B datatype + ADataType, + BDataType, AccDataType, CDataType, InMemoryDataOperationEnum::AtomicAdd, @@ -1090,7 +1096,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle CBlockTransferScalarPerVector_NWaveNPerXdl, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, true, - true>; + true, + 1, + PipelineVersion::v1, + ComputeTypeA, + ComputeTypeB>; // Argument using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = @@ -1217,8 +1227,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle index_t M01_; index_t N01_; - InElementwiseOperation a_element_op_; - OutElementwiseOperation b_element_op_; + OutElementwiseOperation a_element_op_; + InElementwiseOperation b_element_op_; WeiElementwiseOperation c_element_op_; // for checking IsSupportedArgument() @@ -1281,7 +1291,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle const auto kernel = kernel_batched_gemm_xdlops_bwd_weight< GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype + ADataType, + BDataType, CDataType, OutElementwiseOperation, InElementwiseOperation, diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 927574dfde..337117551f 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -185,7 +185,8 @@ struct PassThrough template <> __host__ __device__ void operator()(bf8_t& y, const half_t& x) const { - y = type_convert(x); + // to-do: fix half_t to bf8_t convert + y = ck::type_convert(ck::type_convert(x)); } #endif }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp index 8897ce5a10..06c87d1892 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -139,7 +139,8 @@ __host__ __device__ constexpr auto make_merge_transform_v4_no_carry(const LowLen } template + PipelineVersion PipelineVer = PipelineVersion::v1, + typename ComputeTypeA = FloatA, + typename ComputeTypeB = ComputeTypeA> struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight { static constexpr auto I0 = Number<0>{}; @@ -265,11 +269,16 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight // denorm test fix, required to work around fp16 mfma issue // we convert fp16->fp32->bf16 and execute bf16 mfma instruction // when mfma if fixed, remove this section and update - // FloatABAdjusted -> FloatAB throughout this file + // FloatAAdjusted -> ComputeTypeA, FloatBAdjusted -> ComputeTypeB, + // throughout this file #if CK_WORKAROUND_DENORM_FIX - using FloatABAdjusted = conditional_t, ck::bhalf_t, FloatAB>; + using FloatAAdjusted = + conditional_t, ck::bhalf_t, ComputeTypeA>; + using FloatBAdjusted = + conditional_t, ck::bhalf_t, ComputeTypeB>; #else - using FloatABAdjusted = FloatAB; + using FloatAAdjusted = ComputeTypeA; + using FloatBAdjusted = ComputeTypeB; #endif // M0/M1/M1Padding @@ -506,7 +515,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight constexpr auto c_block_size = GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize(); - return math::max((a_block_space_size + b_block_space_size) * sizeof(FloatAB), + return math::max((a_block_space_size * sizeof(FloatAAdjusted) + + b_block_space_size * sizeof(FloatBAdjusted)), c_block_size * sizeof(FloatC)); } @@ -610,8 +620,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1)); template - __device__ static void Run(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, + __device__ static void Run(const FloatA* __restrict__ p_a_grid, + const FloatB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, void* __restrict__ p_shared, const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc, @@ -673,8 +683,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight Sequence<1, K0PerBlock, MPerBlock, K1>, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatABAdjusted, + FloatA, + FloatAAdjusted, decltype(a_b_k0_m_k1_grid_desc), decltype(a_b_k0_m_k1_block_desc), ABlockTransferSrcAccessOrder, @@ -703,8 +713,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight Sequence<1, K0PerBlock, NPerBlock, K1>, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatABAdjusted, + FloatB, + FloatBAdjusted, decltype(b_b_k0_n_k1_grid_desc), decltype(b_b_k0_n_k1_block_desc), BBlockTransferSrcAccessOrder, @@ -733,12 +743,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight // sanity check constexpr index_t KPack = - math::max(K1, MfmaSelector::selected_mfma.k_per_blk); + math::max(K1, + MfmaSelector::selected_mfma + .k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1( - static_cast(p_shared), a_k0_m_k1_block_desc.GetElementSpaceSize()); + static_cast(p_shared), a_k0_m_k1_block_desc.GetElementSpaceSize()); auto b_block_buf = make_dynamic_buffer( - static_cast(p_shared) + a_block_space_size, + static_cast(p_shared) + a_block_space_size, b_k0_n_k1_block_desc.GetElementSpaceSize()); // gridwise GEMM pipeline diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index 06b95aef21..c8e56fbc56 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -32,8 +32,12 @@ enum struct MfmaInstr mfma_f64_16x16x4f64, mfma_f32_32x32x16f8f8, mfma_f32_16x16x32f8f8, + mfma_f32_32x32x16bf8bf8, + mfma_f32_16x16x32bf8bf8, mfma_f32_32x32x16f8bf8, - mfma_f32_16x16x32f8bf8 + mfma_f32_16x16x32f8bf8, + mfma_f32_32x32x16bf8f8, + mfma_f32_16x16x32bf8f8 }; template @@ -504,6 +508,52 @@ struct mfma_type }; #endif +#if defined CK_ENABLE_BF8 +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x16bf8bf8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x32bf8bf8::Run(a, b, reg_c); + } +}; +#endif + #if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 template <> struct mfma_type @@ -550,6 +600,52 @@ struct mfma_type }; #endif +#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x16bf8f8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x32bf8f8::Run(a, b, reg_c); + } +}; +#endif + template + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x16bf8bf8; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x32bf8bf8; + } +#endif + #if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 template <> static constexpr auto GetMfma() @@ -724,6 +834,20 @@ struct MfmaSelector } #endif +#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x16bf8f8; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x32bf8f8; + } +#endif + static constexpr auto selected_mfma = mfma_type()>{}; @@ -931,8 +1055,12 @@ struct XdlopsGemm #if defined CK_ENABLE_FP8 || is_same::value #endif +#if defined CK_ENABLE_BF8 + || is_same::value +#endif #if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 - || (is_same::value && is_same::value) + || (is_same::value && is_same::value) || + (is_same::value && is_same::value) #endif , "base base_type must be double, float, half, bfloat16, int8_t, f8_t or bf8_t!"); diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index 3768f92633..bf6241e46e 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -420,6 +420,71 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16> }; #endif +#if defined CK_ENABLE_BF8 +template +struct intrin_mfma_f32_32x32x16bf8bf8; + +template <> +struct intrin_mfma_f32_32x32x16bf8bf8<32, 32> +{ + template + __device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); +#else + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + static_for<0, 8, 1>{}([&](auto k) { + float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + + intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c); + }); +#endif + } +}; + +template +struct intrin_mfma_f32_16x16x32bf8bf8; + +template <> +struct intrin_mfma_f32_16x16x32bf8bf8<16, 16> +{ + template + __device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); +#else + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + static_for<0, 8, 1>{}([&](auto k) { + float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + + intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c); + }); +#endif + } +}; +#endif + #if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 template struct intrin_mfma_f32_32x32x16f8bf8; @@ -484,5 +549,70 @@ struct intrin_mfma_f32_16x16x32f8bf8<16, 16> } }; #endif + +#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 +template +struct intrin_mfma_f32_32x32x16bf8f8; + +template <> +struct intrin_mfma_f32_32x32x16bf8f8<32, 32> +{ + template + __device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); +#else + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + static_for<0, 8, 1>{}([&](auto k) { + float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + + intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c); + }); +#endif + } +}; + +template +struct intrin_mfma_f32_16x16x32bf8f8; + +template <> +struct intrin_mfma_f32_16x16x32bf8f8<16, 16> +{ + template + __device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); +#else + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + static_for<0, 8, 1>{}([&](auto k) { + float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + + intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c); + }); +#endif + } +}; +#endif } // namespace ck #endif diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 935926070d..505de73ae2 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -221,7 +221,7 @@ inline __host__ __device__ bf8_t type_convert(half_t x) { #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) // convert to float and use native converion - return type_convert(type_convert(x)); + return type_convert(type_convert(x)); #else constexpr bool negative_zero_nan = true; constexpr bool clip = true; diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp index ec5df238ab..02ad7a033a 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp @@ -25,6 +25,8 @@ template = 1 && NDimSpatial <= 3, bool>::type = false> struct ReferenceConvBwdWeight : public device::BaseOperator { @@ -98,8 +100,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator if(wi >= 0 && ck::type_convert(wi) < arg.input_.GetLengths()[3]) { - float v_out; - float v_in; + ComputeTypeA v_out; + ComputeTypeB v_in; arg.out_element_op_( v_out, ck::type_convert(arg.output_(g, n, k, wo))); @@ -107,7 +109,7 @@ struct ReferenceConvBwdWeight : public device::BaseOperator arg.in_element_op_( v_in, ck::type_convert(arg.input_(g, n, c, wi))); - v_acc += v_out * v_in; + v_acc += type_convert(v_out) * type_convert(v_in); } } } @@ -158,8 +160,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator wi >= 0 && ck::type_convert(wi) < arg.input_.GetLengths()[4]) { - float v_out; - float v_in; + ComputeTypeA v_out; + ComputeTypeB v_in; arg.out_element_op_( v_out, @@ -168,7 +170,7 @@ struct ReferenceConvBwdWeight : public device::BaseOperator arg.in_element_op_( v_in, ck::type_convert(arg.input_(g, n, c, hi, wi))); - v_acc += v_out * v_in; + v_acc += type_convert(v_out) * type_convert(v_in); } } } @@ -226,8 +228,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ck::type_convert(wi) < arg.input_.GetLengths()[5]) { - float v_out; - float v_in; + ComputeTypeA v_out; + ComputeTypeB v_in; arg.out_element_op_(v_out, ck::type_convert( @@ -237,7 +239,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ck::type_convert( arg.input_(g, n, c, di, hi, wi))); - v_acc += v_out * v_in; + v_acc += + type_convert(v_out) * type_convert(v_in); } } } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp index 6fc91565b9..096e0b1770 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp @@ -19,6 +19,14 @@ using BF16 = ck::bhalf_t; using F16 = ck::half_t; using F32 = float; +#ifdef CK_ENABLE_FP8 +using F8 = ck::f8_t; +#endif + +#ifdef CK_ENABLE_BF8 +using BF8 = ck::bf8_t; +#endif + using Empty_Tuple = ck::Tuple<>; template @@ -133,6 +141,43 @@ using device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances = std::tuple< // clang-format on >; +template +using device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_comp_bf8_f8_instances = std::tuple< +// clang-format off + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| Compute| Compute| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| TypeA| TypeB| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| | | + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | +#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 + // generic instance + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2, BF8, F8>, + // instance for small conv.K + // for fp16 conv.K and conv.C must be divisible by 2 + // since half_t atomic_add require scalar_per_x_vector % 2 == 0 + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2, BF8, F8>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>, + + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8> +#endif + // clang-format on + >; + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp index c0a5b4bfe5..9f536f210d 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp @@ -216,6 +216,21 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances PassThrough, PassThrough>>>& instances); #endif +#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances( + std::vector>>& instances); +#endif #ifdef DL_KERNELS // dl @@ -464,7 +479,9 @@ template + typename OutDataType, + typename ComputeTypeA, + typename ComputeTypeB> struct DeviceOperationInstanceFactory> + ck::tensor_operation::element_wise::PassThrough, + ComputeTypeA, + ComputeTypeB>> { using DeviceOp = DeviceGroupedConvBwdWeight; + ck::tensor_operation::element_wise::PassThrough, + ComputeTypeA, + ComputeTypeB>; static auto GetInstances() { @@ -706,7 +727,9 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + is_same_v && + is_same_v && + is_same_v) { #ifdef DL_KERNELS add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instances( @@ -728,6 +751,15 @@ struct DeviceOperationInstanceFactory && is_same_v && + is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances( + op_ptrs); + } #endif } } diff --git a/library/include/ck/library/utility/host_tensor_generator.hpp b/library/include/ck/library/utility/host_tensor_generator.hpp index 07d3a35a18..6b531c4c6f 100644 --- a/library/include/ck/library/utility/host_tensor_generator.hpp +++ b/library/include/ck/library/utility/host_tensor_generator.hpp @@ -111,6 +111,22 @@ struct GeneratorTensor_2 }; #endif +#if defined CK_ENABLE_BF8 +template <> +struct GeneratorTensor_2 +{ + int min_value = 0; + int max_value = 1; + + template + ck::bf8_t operator()(Is...) + { + float tmp = (std::rand() % (max_value - min_value)) + min_value; + return ck::type_convert(tmp); + } +}; +#endif + template struct GeneratorTensor_3 { @@ -162,6 +178,25 @@ struct GeneratorTensor_3 }; #endif +#if defined CK_ENABLE_BF8 +template <> +struct GeneratorTensor_3 +{ + float min_value = 0; + float max_value = 1; + + template + ck::bf8_t operator()(Is...) + { + float tmp = float(std::rand()) / float(RAND_MAX); + + float fp32_tmp = min_value + tmp * (max_value - min_value); + + return ck::type_convert(fp32_tmp); + } +}; +#endif + template struct GeneratorTensor_4 { diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt index 6bd9d6e646..65ea645a39 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt @@ -4,7 +4,8 @@ set(GROUPED_CONV3D_BWD_WEIGHT device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp - device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp) + device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp) if(DL_KERNELS) list(APPEND GROUPED_CONV3D_BWD_WEIGHT diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp new file mode 100644 index 0000000000..7f9493f602 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_comp_bf8_f8_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_comp_bf8_f8_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp index 48bf639a70..0300aed436 100644 --- a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp @@ -33,7 +33,9 @@ template + typename OutDataType, + typename ComputeTypeA = InDataType, + typename ComputeTypeB = ComputeTypeA> bool profile_grouped_conv_bwd_weight_impl(int do_verification, int init_method, bool do_log, @@ -120,7 +122,9 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, OutDataType, InElementOp, WeiElementOp, - OutElementOp>; + OutElementOp, + ComputeTypeA, + ComputeTypeB>; // get device op instances const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< diff --git a/profiler/src/profile_grouped_conv_bwd_weight.cpp b/profiler/src/profile_grouped_conv_bwd_weight.cpp index be8a3230f7..9a9b5a0f74 100644 --- a/profiler/src/profile_grouped_conv_bwd_weight.cpp +++ b/profiler/src/profile_grouped_conv_bwd_weight.cpp @@ -20,9 +20,10 @@ enum struct ConvLayout enum struct ConvDataType { - F32_F32_F32, // 0 - F16_F16_F16, // 1 - BF16_F32_BF16, // 2 + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_F32_BF16, // 2 + F16_F16_F16_BF8_F8 // 3 }; #define OP_NAME "grouped_conv_bwd_weight" @@ -33,7 +34,8 @@ static void print_helper_msg() std::cout << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" << "arg2: data type (0: Input fp32, Weight fp32, Output fp32\n" << " 1: Input fp16, Weight fp16, Output fp16\n" - << " 2: Input bf16, Weight fp32, Output bf16)\n" + << " 2: Input bf16, Weight fp32, Output bf16\n" + << " 3: Input fp16, Weight fp16, Output fp16, Gemm bf8@fp8)\n" << "arg3: tensor layout (0: Input[G, N, C, Hi, Wi], Weight[G, K, C, Y, X], Output[G, " "N, K, Ho, Wo]\n" << " 1: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, " @@ -82,6 +84,12 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) using F32 = float; using F16 = ck::half_t; using BF16 = ck::bhalf_t; +#ifdef CK_ENABLE_FP8 + using F8 = ck::f8_t; +#endif +#ifdef CK_ENABLE_BF8 + using BF8 = ck::bf8_t; +#endif using namespace ck::tensor_layout::convolution; @@ -95,7 +103,9 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) auto out_layout, auto in_type, auto wei_type, - auto out_type) { + auto out_type, + auto compute_type_a, + auto compute_type_b) { constexpr ck::index_t NDimSpatial = num_dim_spatial_tmp.value; using InLayout = decltype(in_layout); @@ -106,13 +116,18 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) using WeiDataType = decltype(wei_type); using OutDataType = decltype(out_type); + using ComputeTypeA = decltype(compute_type_a); + using ComputeTypeB = decltype(compute_type_b); + bool pass = ck::profiler::profile_grouped_conv_bwd_weight_impl( + OutDataType, + ComputeTypeA, + ComputeTypeB>( do_verification, init_method, do_log, time_kernel, params, split_k); return pass ? 0 : 1; @@ -122,80 +137,84 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) { if(data_type == ConvDataType::F32_F32_F32) { - return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}); + return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, F32{}, F32{}); } else if(data_type == ConvDataType::F16_F16_F16) { - return profile(I1, GNWC{}, GKXC{}, GNWK{}, F16{}, F16{}, F16{}); + return profile(I1, GNWC{}, GKXC{}, GNWK{}, F16{}, F16{}, F16{}, F16{}, F16{}); } else if(data_type == ConvDataType::BF16_F32_BF16) { // fp32 atomic add is used for weight tensor in bf16 kernel - return profile(I1, GNWC{}, GKXC{}, GNWK{}, BF16{}, F32{}, BF16{}); + return profile(I1, GNWC{}, GKXC{}, GNWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{}); } } else if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) { if(data_type == ConvDataType::F32_F32_F32) { - return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}); + return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, F32{}, F32{}); } else if(data_type == ConvDataType::F16_F16_F16) { - return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F16{}, F16{}, F16{}); + return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F16{}, F16{}, F16{}, F16{}, F16{}); } else if(data_type == ConvDataType::BF16_F32_BF16) { // fp32 atomic add is used for weight tensor in bf16 kernel - return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, BF16{}, F32{}, BF16{}); + return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{}); } } else if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) { if(data_type == ConvDataType::F32_F32_F32) { - return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}); + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{}); } else if(data_type == ConvDataType::F16_F16_F16) { - return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{}); + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{}); } else if(data_type == ConvDataType::BF16_F32_BF16) { // fp32 atomic add is used for weight tensor in bf16 kernel - return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, F32{}, BF16{}); + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{}); } } else if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) { if(data_type == ConvDataType::F32_F32_F32) { - return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}); + return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, F32{}, F32{}); } else if(data_type == ConvDataType::F16_F16_F16) { - return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F16{}, F16{}, F16{}); + return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F16{}, F16{}, F16{}, F16{}, F16{}); } else if(data_type == ConvDataType::BF16_F32_BF16) { // fp32 atomic add is used for weight tensor in bf16 kernel - return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, F32{}, BF16{}); + return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{}); } } else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) { if(data_type == ConvDataType::F32_F32_F32) { - return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}); + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{}); } else if(data_type == ConvDataType::F16_F16_F16) { - return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}); + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{}); } else if(data_type == ConvDataType::BF16_F32_BF16) { // fp32 atomic add is used for weight tensor in bf16 kernel - return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, F32{}, BF16{}); + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::F16_F16_F16_BF8_F8) + { + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, BF8{}, F8{}); } }