mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 21:39:15 +00:00
Add conv bwd weight fp16 comp bf8 fp8 op, instances and example (#945)
* Add f8 bf8 gemm example
* Add element-wise ops
* Add intrinsics
* Update reference calculation
* Add an additional type option for xdlops gemm
* Fix build process
* Add bf8 to buffer addressing
* Update blockwise op, split typeA and typeB
* Update for compatibility
* Uppdate naming to f8->fp8
* Update naming
* Format
* Update naming (#937)
* Add a client example
* Add computetypes to device and gridwise ops
* Add instances, update instance factory
* Format
* Fix a flag
* Add ckProfiler mode
* Fix typos
* Add an example
* Add bf8 generator
* add bf8 mfma; fixed type_convert for bf8
* move verfication ahead of timing
* Update reference calculation
* Fix reference
* Narrow down float init range
* Fix bf8 bf8 mfma
* Add bf8 @ fp8 mfma
* Update example
* Update instances
* Update profiler api
* Update for compatibility
* Format
* Remove extra example
* Clean up
* workaround convert
---------
Co-authored-by: Jing Zhang <jizha@amd.com>
[ROCm/composable_kernel commit: 42facfc6b7]
This commit is contained in:
@@ -11,6 +11,12 @@ foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(result EQUAL 0)
|
||||
add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16)
|
||||
endif()
|
||||
if(GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR GPU_TARGETS MATCHES "gfx942")
|
||||
add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp)
|
||||
if(result EQUAL 0)
|
||||
add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8)
|
||||
endif()
|
||||
endif()
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
@@ -23,6 +23,12 @@
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
#ifdef CK_ENABLE_FP8
|
||||
using F8 = ck::f8_t;
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF8
|
||||
using BF8 = ck::bf8_t;
|
||||
#endif
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
@@ -65,6 +65,15 @@ using DeviceConvBwdWeightInstance = ck::tensor_operation::device::DeviceGroupedC
|
||||
5, // CThreadTransferSrcDstVectorDim
|
||||
4>; // CThreadTransferDstScalarPerVector
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp>;
|
||||
|
||||
#include "run_grouped_conv_bwd_weight_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); }
|
||||
|
||||
@@ -67,6 +67,15 @@ using DeviceConvBwdWeightInstance =
|
||||
S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
128 / (sizeof(WeiDataType) * CHAR_BIT)>; // CBlockTransferScalarPerVector_NWaveNPerXdl
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp>;
|
||||
|
||||
#include "run_grouped_conv_bwd_weight_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); }
|
||||
|
||||
@@ -66,6 +66,15 @@ using DeviceConvBwdWeightInstance =
|
||||
S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
128 / (sizeof(WeiDataType) * CHAR_BIT)>; // CBlockTransferScalarPerVector_NWaveNPerXdl
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp>;
|
||||
|
||||
#include "run_grouped_conv_bwd_weight_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); }
|
||||
|
||||
@@ -0,0 +1,86 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
|
||||
|
||||
using InDataType = F16;
|
||||
using WeiDataType = F16;
|
||||
using OutDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using ComputeTypeA = BF8;
|
||||
using ComputeTypeB = F8;
|
||||
|
||||
using InElementOp = PassThrough;
|
||||
using WeiElementOp = PassThrough;
|
||||
using OutElementOp = PassThrough;
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
using DeviceConvBwdWeightInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle<
|
||||
NDimSpatial,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GNWC,
|
||||
ck::tensor_layout::convolution::GNHWC,
|
||||
ck::tensor_layout::convolution::GNDHWC>>,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GKXC,
|
||||
ck::tensor_layout::convolution::GKYXC,
|
||||
ck::tensor_layout::convolution::GKZYXC>>,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GNWK,
|
||||
ck::tensor_layout::convolution::GNHWK,
|
||||
ck::tensor_layout::convolution::GNDHWK>>,
|
||||
InDataType, // InDataType
|
||||
WeiDataType, // WeiDataType
|
||||
OutDataType, // OutDataType
|
||||
AccDataType, // AccDataType
|
||||
InElementOp, // InElementwiseOperation
|
||||
WeiElementOp, // WeiElementwiseOperation
|
||||
OutElementOp, // OutElementwiseOperation
|
||||
ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
4, // K0PerBlock
|
||||
8, // K1
|
||||
32, // MPerXdl
|
||||
32, // NPerXdl
|
||||
2, // MXdlPerWave
|
||||
2, // NXdlPerWave
|
||||
S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1
|
||||
S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
1, // ABlockTransferSrcScalarPerVector
|
||||
1, // ABlockTransferDstScalarPerVector_K1
|
||||
true, // ABlockLdsAddExtraM
|
||||
S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1
|
||||
S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
1, // BBlockTransferSrcScalarPerVector
|
||||
1, // BBlockTransferDstScalarPerVector_K1
|
||||
true, // BBlockLdsAddExtraN
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
1, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
2, // CBlockTransferScalarPerVector_NWaveNPerXdl
|
||||
ComputeTypeA, // ComputeTypeA
|
||||
ComputeTypeB>; // ComputeTypeB
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>;
|
||||
|
||||
#include "run_grouped_conv_bwd_weight_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); }
|
||||
@@ -1,15 +1,6 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp>;
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
|
||||
const ck::utils::conv::ConvParam& conv_param)
|
||||
@@ -46,8 +37,8 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
|
||||
out.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
|
||||
break;
|
||||
default:
|
||||
in.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
|
||||
out.GenerateTensorValue(GeneratorTensor_3<OutDataType>{-0.5, 0.5});
|
||||
in.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 0.2});
|
||||
out.GenerateTensorValue(GeneratorTensor_3<OutDataType>{-0.1, 0.1});
|
||||
}
|
||||
|
||||
DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
|
||||
@@ -113,18 +104,7 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
|
||||
return true;
|
||||
}
|
||||
|
||||
float avg_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
|
||||
|
||||
std::size_t flop = conv_param.GetFlops();
|
||||
std::size_t num_btype = conv_param.GetByte<InDataType, WeiDataType, OutDataType>();
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / avg_time;
|
||||
|
||||
std::cerr << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
||||
<< std::endl
|
||||
<< "DeviceOp: " << conv.GetTypeString() << std::endl;
|
||||
invoker.Run(argument, StreamConfig{nullptr, false});
|
||||
|
||||
if(config.do_verification)
|
||||
{
|
||||
@@ -148,6 +128,19 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
|
||||
return ck::utils::check_err(wei_device_result.mData, wei_host_result.mData);
|
||||
}
|
||||
|
||||
float avg_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
|
||||
|
||||
std::size_t flop = conv_param.GetFlops();
|
||||
std::size_t num_btype = conv_param.GetByte<InDataType, WeiDataType, OutDataType>();
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / avg_time;
|
||||
|
||||
std::cerr << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
||||
<< std::endl
|
||||
<< "DeviceOp: " << conv.GetTypeString() << std::endl;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
@@ -20,7 +20,9 @@ template <ck::index_t NDimSpatial,
|
||||
typename OutDataType,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
typename OutElementwiseOperation,
|
||||
typename ComputeTypeA = InDataType,
|
||||
typename ComputeTypeB = ComputeTypeA>
|
||||
struct DeviceGroupedConvBwdWeight : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
|
||||
@@ -48,7 +48,8 @@ struct ComputePtrOffsetOfStridedBatch
|
||||
} // namespace
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatAB,
|
||||
typename FloatA,
|
||||
typename FloatB,
|
||||
typename FloatC,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
@@ -64,8 +65,8 @@ __global__ void
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_batched_gemm_xdlops_bwd_weight(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
const FloatA* __restrict__ p_a_grid,
|
||||
const FloatB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
@@ -91,7 +92,7 @@ __global__ void
|
||||
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
|
||||
|
||||
__shared__ FloatAB p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB)];
|
||||
__shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
|
||||
p_b_grid + b_batch_offset,
|
||||
@@ -163,7 +164,9 @@ template <ck::index_t NDimSpatial,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
|
||||
index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
|
||||
typename ComputeTypeA = InDataType,
|
||||
typename ComputeTypeB = ComputeTypeA>
|
||||
struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
: public DeviceGroupedConvBwdWeight<NDimSpatial,
|
||||
InLayout,
|
||||
@@ -174,7 +177,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
OutDataType,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
OutElementwiseOperation,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>
|
||||
{
|
||||
using DeviceOp = DeviceGroupedConvBwdWeight_Xdl_CShuffle;
|
||||
|
||||
@@ -1045,7 +1050,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
|
||||
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight<
|
||||
BlockSize,
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
@@ -1090,7 +1096,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
CBlockTransferScalarPerVector_NWaveNPerXdl,
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
true,
|
||||
true>;
|
||||
true,
|
||||
1,
|
||||
PipelineVersion::v1,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>;
|
||||
|
||||
// Argument
|
||||
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
@@ -1217,8 +1227,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
index_t M01_;
|
||||
index_t N01_;
|
||||
|
||||
InElementwiseOperation a_element_op_;
|
||||
OutElementwiseOperation b_element_op_;
|
||||
OutElementwiseOperation a_element_op_;
|
||||
InElementwiseOperation b_element_op_;
|
||||
WeiElementwiseOperation c_element_op_;
|
||||
|
||||
// for checking IsSupportedArgument()
|
||||
@@ -1281,7 +1291,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
|
||||
const auto kernel = kernel_batched_gemm_xdlops_bwd_weight<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
OutElementwiseOperation,
|
||||
InElementwiseOperation,
|
||||
|
||||
@@ -185,7 +185,8 @@ struct PassThrough
|
||||
template <>
|
||||
__host__ __device__ void operator()<bf8_t, half_t>(bf8_t& y, const half_t& x) const
|
||||
{
|
||||
y = type_convert<bf8_t>(x);
|
||||
// to-do: fix half_t to bf8_t convert
|
||||
y = ck::type_convert<bf8_t>(ck::type_convert<float>(x));
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
@@ -139,7 +139,8 @@ __host__ __device__ constexpr auto make_merge_transform_v4_no_carry(const LowLen
|
||||
}
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatAB,
|
||||
typename FloatA,
|
||||
typename FloatB,
|
||||
typename FloatC,
|
||||
typename AGridDesc_B_K0_M_K1,
|
||||
typename BGridDesc_B_K0_N_K1,
|
||||
@@ -153,8 +154,8 @@ __global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_gemm_xdlops_bwd_weight(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
kernel_gemm_xdlops_bwd_weight(const FloatA* __restrict__ p_a_grid,
|
||||
const FloatB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc,
|
||||
const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc,
|
||||
@@ -181,21 +182,22 @@ __global__ void
|
||||
c_element_op,
|
||||
c_block_cluster_adaptor);
|
||||
#else
|
||||
ignore = p_a_grid;
|
||||
ignore = p_b_grid;
|
||||
ignore = p_c_grid;
|
||||
ignore = a_b_k0_m_k1_grid_desc;
|
||||
ignore = b_b_k0_n_k1_grid_desc;
|
||||
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
ignore = c_element_op;
|
||||
ignore = c_block_cluster_adaptor;
|
||||
ignore = p_a_grid;
|
||||
ignore = p_b_grid;
|
||||
ignore = p_c_grid;
|
||||
ignore = a_b_k0_m_k1_grid_desc;
|
||||
ignore = b_b_k0_n_k1_grid_desc;
|
||||
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
ignore = c_element_op;
|
||||
ignore = c_block_cluster_adaptor;
|
||||
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
|
||||
}
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatA,
|
||||
typename FloatB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
@@ -242,7 +244,9 @@ template <index_t BlockSize,
|
||||
bool ABlockLdsExtraM1Wrw = false,
|
||||
bool BBlockLdsExtraN1Wrw = false,
|
||||
index_t NumGemmKPrefetchStage = 1,
|
||||
PipelineVersion PipelineVer = PipelineVersion::v1>
|
||||
PipelineVersion PipelineVer = PipelineVersion::v1,
|
||||
typename ComputeTypeA = FloatA,
|
||||
typename ComputeTypeB = ComputeTypeA>
|
||||
struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
@@ -265,11 +269,16 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
// denorm test fix, required to work around fp16 mfma issue
|
||||
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
|
||||
// when mfma if fixed, remove this section and update
|
||||
// FloatABAdjusted -> FloatAB throughout this file
|
||||
// FloatAAdjusted -> ComputeTypeA, FloatBAdjusted -> ComputeTypeB,
|
||||
// throughout this file
|
||||
#if CK_WORKAROUND_DENORM_FIX
|
||||
using FloatABAdjusted = conditional_t<is_same_v<FloatAB, ck::half_t>, ck::bhalf_t, FloatAB>;
|
||||
using FloatAAdjusted =
|
||||
conditional_t<is_same_v<ComputeTypeA, ck::half_t>, ck::bhalf_t, ComputeTypeA>;
|
||||
using FloatBAdjusted =
|
||||
conditional_t<is_same_v<ComputeTypeB, ck::half_t>, ck::bhalf_t, ComputeTypeB>;
|
||||
#else
|
||||
using FloatABAdjusted = FloatAB;
|
||||
using FloatAAdjusted = ComputeTypeA;
|
||||
using FloatBAdjusted = ComputeTypeB;
|
||||
#endif
|
||||
|
||||
// M0/M1/M1Padding
|
||||
@@ -506,7 +515,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
constexpr auto c_block_size =
|
||||
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize();
|
||||
|
||||
return math::max((a_block_space_size + b_block_space_size) * sizeof(FloatAB),
|
||||
return math::max((a_block_space_size * sizeof(FloatAAdjusted) +
|
||||
b_block_space_size * sizeof(FloatBAdjusted)),
|
||||
c_block_size * sizeof(FloatC));
|
||||
}
|
||||
|
||||
@@ -610,8 +620,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1));
|
||||
|
||||
template <bool HasMainKBlockLoop>
|
||||
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
__device__ static void Run(const FloatA* __restrict__ p_a_grid,
|
||||
const FloatB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
void* __restrict__ p_shared,
|
||||
const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
|
||||
@@ -673,8 +683,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
Sequence<1, K0PerBlock, MPerBlock, K1>,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatABAdjusted,
|
||||
FloatA,
|
||||
FloatAAdjusted,
|
||||
decltype(a_b_k0_m_k1_grid_desc),
|
||||
decltype(a_b_k0_m_k1_block_desc),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
@@ -703,8 +713,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
Sequence<1, K0PerBlock, NPerBlock, K1>,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatABAdjusted,
|
||||
FloatB,
|
||||
FloatBAdjusted,
|
||||
decltype(b_b_k0_n_k1_grid_desc),
|
||||
decltype(b_b_k0_n_k1_block_desc),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
@@ -733,12 +743,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
// sanity check
|
||||
|
||||
constexpr index_t KPack =
|
||||
math::max(K1, MfmaSelector<FloatABAdjusted, MPerXDL, NPerXDL>::selected_mfma.k_per_blk);
|
||||
math::max(K1,
|
||||
MfmaSelector<FloatAAdjusted, MPerXDL, NPerXDL, FloatBAdjusted>::selected_mfma
|
||||
.k_per_blk);
|
||||
|
||||
auto blockwise_gemm =
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
FloatABAdjusted,
|
||||
FloatABAdjusted,
|
||||
FloatAAdjusted,
|
||||
FloatBAdjusted,
|
||||
FloatAcc,
|
||||
decltype(a_k0_m_k1_block_desc),
|
||||
decltype(b_k0_n_k1_block_desc),
|
||||
@@ -758,10 +770,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
|
||||
|
||||
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<FloatABAdjusted*>(p_shared), a_k0_m_k1_block_desc.GetElementSpaceSize());
|
||||
static_cast<FloatAAdjusted*>(p_shared), a_k0_m_k1_block_desc.GetElementSpaceSize());
|
||||
|
||||
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<FloatABAdjusted*>(p_shared) + a_block_space_size,
|
||||
static_cast<FloatBAdjusted*>(p_shared) + a_block_space_size,
|
||||
b_k0_n_k1_block_desc.GetElementSpaceSize());
|
||||
|
||||
// gridwise GEMM pipeline
|
||||
|
||||
@@ -32,8 +32,12 @@ enum struct MfmaInstr
|
||||
mfma_f64_16x16x4f64,
|
||||
mfma_f32_32x32x16f8f8,
|
||||
mfma_f32_16x16x32f8f8,
|
||||
mfma_f32_32x32x16bf8bf8,
|
||||
mfma_f32_16x16x32bf8bf8,
|
||||
mfma_f32_32x32x16f8bf8,
|
||||
mfma_f32_16x16x32f8bf8
|
||||
mfma_f32_16x16x32f8bf8,
|
||||
mfma_f32_32x32x16bf8f8,
|
||||
mfma_f32_16x16x32bf8f8
|
||||
};
|
||||
|
||||
template <MfmaInstr instr>
|
||||
@@ -504,6 +508,52 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8>
|
||||
};
|
||||
#endif
|
||||
|
||||
#if defined CK_ENABLE_BF8
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_32x32x16bf8bf8>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_per_blk = 4;
|
||||
static constexpr index_t num_regs_per_blk = 16;
|
||||
static constexpr index_t num_threads_per_blk = 32;
|
||||
static constexpr index_t wave_size = 64;
|
||||
static constexpr index_t num_input_blks = 2;
|
||||
static constexpr index_t num_output_blks = 1;
|
||||
static constexpr index_t m_per_blk = 32;
|
||||
static constexpr index_t n_per_blk = 32;
|
||||
static constexpr index_t k_per_blk = 8;
|
||||
static constexpr bool is_k_reduction = true;
|
||||
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_f32_32x32x16bf8bf8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8bf8>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_per_blk = 1;
|
||||
static constexpr index_t num_regs_per_blk = 4;
|
||||
static constexpr index_t num_threads_per_blk = 16;
|
||||
static constexpr index_t wave_size = 64;
|
||||
static constexpr index_t num_input_blks = 4;
|
||||
static constexpr index_t num_output_blks = 1;
|
||||
static constexpr index_t m_per_blk = 16;
|
||||
static constexpr index_t n_per_blk = 16;
|
||||
static constexpr index_t k_per_blk = 8;
|
||||
static constexpr bool is_k_reduction = true;
|
||||
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_f32_16x16x32bf8bf8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8bf8>
|
||||
@@ -550,6 +600,52 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8bf8>
|
||||
};
|
||||
#endif
|
||||
|
||||
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_32x32x16bf8f8>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_per_blk = 4;
|
||||
static constexpr index_t num_regs_per_blk = 16;
|
||||
static constexpr index_t num_threads_per_blk = 32;
|
||||
static constexpr index_t wave_size = 64;
|
||||
static constexpr index_t num_input_blks = 2;
|
||||
static constexpr index_t num_output_blks = 1;
|
||||
static constexpr index_t m_per_blk = 32;
|
||||
static constexpr index_t n_per_blk = 32;
|
||||
static constexpr index_t k_per_blk = 8;
|
||||
static constexpr bool is_k_reduction = true;
|
||||
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_f32_32x32x16bf8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8f8>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_per_blk = 1;
|
||||
static constexpr index_t num_regs_per_blk = 4;
|
||||
static constexpr index_t num_threads_per_blk = 16;
|
||||
static constexpr index_t wave_size = 64;
|
||||
static constexpr index_t num_input_blks = 4;
|
||||
static constexpr index_t num_output_blks = 1;
|
||||
static constexpr index_t m_per_blk = 16;
|
||||
static constexpr index_t n_per_blk = 16;
|
||||
static constexpr index_t k_per_blk = 8;
|
||||
static constexpr bool is_k_reduction = true;
|
||||
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_f32_16x16x32bf8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename base_type,
|
||||
index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
@@ -710,6 +806,20 @@ struct MfmaSelector
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined CK_ENABLE_BF8
|
||||
template <>
|
||||
static constexpr auto GetMfma<bf8_t, 32, 32>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_32x32x16bf8bf8;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<bf8_t, 16, 16>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_16x16x32bf8bf8;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
template <>
|
||||
static constexpr auto GetMfma<f8_t, 32, 32, bf8_t>()
|
||||
@@ -724,6 +834,20 @@ struct MfmaSelector
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
template <>
|
||||
static constexpr auto GetMfma<bf8_t, 32, 32, f8_t>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_32x32x16bf8f8;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<bf8_t, 16, 16, f8_t>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_16x16x32bf8f8;
|
||||
}
|
||||
#endif
|
||||
|
||||
static constexpr auto selected_mfma =
|
||||
mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops, additional_type>()>{};
|
||||
|
||||
@@ -931,8 +1055,12 @@ struct XdlopsGemm
|
||||
#if defined CK_ENABLE_FP8
|
||||
|| is_same<base_type, f8_t>::value
|
||||
#endif
|
||||
#if defined CK_ENABLE_BF8
|
||||
|| is_same<base_type, bf8_t>::value
|
||||
#endif
|
||||
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
|| (is_same<base_type, f8_t>::value && is_same<additional_type, bf8_t>::value)
|
||||
|| (is_same<base_type, f8_t>::value && is_same<additional_type, bf8_t>::value) ||
|
||||
(is_same<base_type, bf8_t>::value && is_same<additional_type, f8_t>::value)
|
||||
#endif
|
||||
,
|
||||
"base base_type must be double, float, half, bfloat16, int8_t, f8_t or bf8_t!");
|
||||
|
||||
@@ -420,6 +420,71 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
|
||||
};
|
||||
#endif
|
||||
|
||||
#if defined CK_ENABLE_BF8
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_32x32x16bf8bf8;
|
||||
|
||||
template <>
|
||||
struct intrin_mfma_f32_32x32x16bf8bf8<32, 32>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
|
||||
bit_cast<long>(reg_a),
|
||||
bit_cast<long>(reg_b),
|
||||
reg_c.template AsType<float16_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
#else
|
||||
vector_type<bf8_t, 8> reg_a_v(reg_a);
|
||||
vector_type<bf8_t, 8> reg_b_v(reg_b);
|
||||
|
||||
static_for<0, 8, 1>{}([&](auto k) {
|
||||
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
|
||||
float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[Number<k>{}]);
|
||||
|
||||
intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c);
|
||||
});
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_16x16x32bf8bf8;
|
||||
|
||||
template <>
|
||||
struct intrin_mfma_f32_16x16x32bf8bf8<16, 16>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
|
||||
bit_cast<long>(reg_a),
|
||||
bit_cast<long>(reg_b),
|
||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
#else
|
||||
vector_type<bf8_t, 8> reg_a_v(reg_a);
|
||||
vector_type<bf8_t, 8> reg_b_v(reg_b);
|
||||
|
||||
static_for<0, 8, 1>{}([&](auto k) {
|
||||
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
|
||||
float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[Number<k>{}]);
|
||||
|
||||
intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c);
|
||||
});
|
||||
#endif
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_32x32x16f8bf8;
|
||||
@@ -484,5 +549,70 @@ struct intrin_mfma_f32_16x16x32f8bf8<16, 16>
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_32x32x16bf8f8;
|
||||
|
||||
template <>
|
||||
struct intrin_mfma_f32_32x32x16bf8f8<32, 32>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
|
||||
bit_cast<long>(reg_a),
|
||||
bit_cast<long>(reg_b),
|
||||
reg_c.template AsType<float16_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
#else
|
||||
vector_type<bf8_t, 8> reg_a_v(reg_a);
|
||||
vector_type<f8_t, 8> reg_b_v(reg_b);
|
||||
|
||||
static_for<0, 8, 1>{}([&](auto k) {
|
||||
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
|
||||
float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[Number<k>{}]);
|
||||
|
||||
intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c);
|
||||
});
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_16x16x32bf8f8;
|
||||
|
||||
template <>
|
||||
struct intrin_mfma_f32_16x16x32bf8f8<16, 16>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
|
||||
bit_cast<long>(reg_a),
|
||||
bit_cast<long>(reg_b),
|
||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
#else
|
||||
vector_type<bf8_t, 8> reg_a_v(reg_a);
|
||||
vector_type<f8_t, 8> reg_b_v(reg_b);
|
||||
|
||||
static_for<0, 8, 1>{}([&](auto k) {
|
||||
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
|
||||
float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[Number<k>{}]);
|
||||
|
||||
intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c);
|
||||
});
|
||||
#endif
|
||||
}
|
||||
};
|
||||
#endif
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -221,7 +221,7 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
// convert to float and use native converion
|
||||
return type_convert<f8_t>(type_convert<float>(x));
|
||||
return type_convert<bf8_t>(type_convert<float>(x));
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
|
||||
@@ -25,6 +25,8 @@ template <ck::index_t NDimSpatial,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation,
|
||||
typename ComputeTypeA = OutDataType,
|
||||
typename ComputeTypeB = InDataType,
|
||||
typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
|
||||
struct ReferenceConvBwdWeight : public device::BaseOperator
|
||||
{
|
||||
@@ -98,8 +100,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
|
||||
if(wi >= 0 &&
|
||||
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3])
|
||||
{
|
||||
float v_out;
|
||||
float v_in;
|
||||
ComputeTypeA v_out;
|
||||
ComputeTypeB v_in;
|
||||
|
||||
arg.out_element_op_(
|
||||
v_out, ck::type_convert<float>(arg.output_(g, n, k, wo)));
|
||||
@@ -107,7 +109,7 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
|
||||
arg.in_element_op_(
|
||||
v_in, ck::type_convert<float>(arg.input_(g, n, c, wi)));
|
||||
|
||||
v_acc += v_out * v_in;
|
||||
v_acc += type_convert<float>(v_out) * type_convert<float>(v_in);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -158,8 +160,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
|
||||
wi >= 0 &&
|
||||
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[4])
|
||||
{
|
||||
float v_out;
|
||||
float v_in;
|
||||
ComputeTypeA v_out;
|
||||
ComputeTypeB v_in;
|
||||
|
||||
arg.out_element_op_(
|
||||
v_out,
|
||||
@@ -168,7 +170,7 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
|
||||
arg.in_element_op_(
|
||||
v_in, ck::type_convert<float>(arg.input_(g, n, c, hi, wi)));
|
||||
|
||||
v_acc += v_out * v_in;
|
||||
v_acc += type_convert<float>(v_out) * type_convert<float>(v_in);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -226,8 +228,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
|
||||
ck::type_convert<std::size_t>(wi) <
|
||||
arg.input_.GetLengths()[5])
|
||||
{
|
||||
float v_out;
|
||||
float v_in;
|
||||
ComputeTypeA v_out;
|
||||
ComputeTypeB v_in;
|
||||
|
||||
arg.out_element_op_(v_out,
|
||||
ck::type_convert<float>(
|
||||
@@ -237,7 +239,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
|
||||
ck::type_convert<float>(
|
||||
arg.input_(g, n, c, di, hi, wi)));
|
||||
|
||||
v_acc += v_out * v_in;
|
||||
v_acc +=
|
||||
type_convert<float>(v_out) * type_convert<float>(v_in);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,6 +19,14 @@ using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
#ifdef CK_ENABLE_FP8
|
||||
using F8 = ck::f8_t;
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_BF8
|
||||
using BF8 = ck::bf8_t;
|
||||
#endif
|
||||
|
||||
using Empty_Tuple = ck::Tuple<>;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
@@ -133,6 +141,43 @@ using device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances = std::tuple<
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardWeightSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_comp_bf8_f8_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| Compute| Compute|
|
||||
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| TypeA| TypeB|
|
||||
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| | |
|
||||
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | |
|
||||
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
// generic instance
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2, BF8, F8>,
|
||||
// instance for small conv.K
|
||||
// for fp16 conv.K and conv.C must be divisible by 2
|
||||
// since half_t atomic_add require scalar_per_x_vector % 2 == 0
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2, BF8, F8>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>,
|
||||
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>
|
||||
#endif
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -216,6 +216,21 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BF8,
|
||||
F8>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef DL_KERNELS
|
||||
// dl
|
||||
@@ -464,7 +479,9 @@ template <ck::index_t NumDimSpatial,
|
||||
typename OutLayout,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType>
|
||||
typename OutDataType,
|
||||
typename ComputeTypeA,
|
||||
typename ComputeTypeB>
|
||||
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvBwdWeight<
|
||||
NumDimSpatial,
|
||||
InLayout,
|
||||
@@ -475,7 +492,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
OutDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>>
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>>
|
||||
{
|
||||
using DeviceOp = DeviceGroupedConvBwdWeight<NumDimSpatial,
|
||||
InLayout,
|
||||
@@ -486,7 +505,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
OutDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>;
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
@@ -706,7 +727,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
is_same_v<OutDataType, half_t> &&
|
||||
is_same_v<ComputeTypeA, half_t> &&
|
||||
is_same_v<ComputeTypeB, half_t>)
|
||||
{
|
||||
#ifdef DL_KERNELS
|
||||
add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
@@ -728,6 +751,15 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t> &&
|
||||
is_same_v<ComputeTypeA, bf8_t> && is_same_v<ComputeTypeB, f8_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
@@ -111,6 +111,22 @@ struct GeneratorTensor_2<ck::f8_t>
|
||||
};
|
||||
#endif
|
||||
|
||||
#if defined CK_ENABLE_BF8
|
||||
template <>
|
||||
struct GeneratorTensor_2<ck::bf8_t>
|
||||
{
|
||||
int min_value = 0;
|
||||
int max_value = 1;
|
||||
|
||||
template <typename... Is>
|
||||
ck::bf8_t operator()(Is...)
|
||||
{
|
||||
float tmp = (std::rand() % (max_value - min_value)) + min_value;
|
||||
return ck::type_convert<ck::bf8_t>(tmp);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
struct GeneratorTensor_3
|
||||
{
|
||||
@@ -162,6 +178,25 @@ struct GeneratorTensor_3<ck::f8_t>
|
||||
};
|
||||
#endif
|
||||
|
||||
#if defined CK_ENABLE_BF8
|
||||
template <>
|
||||
struct GeneratorTensor_3<ck::bf8_t>
|
||||
{
|
||||
float min_value = 0;
|
||||
float max_value = 1;
|
||||
|
||||
template <typename... Is>
|
||||
ck::bf8_t operator()(Is...)
|
||||
{
|
||||
float tmp = float(std::rand()) / float(RAND_MAX);
|
||||
|
||||
float fp32_tmp = min_value + tmp * (max_value - min_value);
|
||||
|
||||
return ck::type_convert<ck::bf8_t>(fp32_tmp);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
struct GeneratorTensor_4
|
||||
{
|
||||
|
||||
@@ -4,7 +4,8 @@ set(GROUPED_CONV3D_BWD_WEIGHT
|
||||
device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp
|
||||
device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
|
||||
device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
|
||||
device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp)
|
||||
device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
|
||||
device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp)
|
||||
|
||||
if(DL_KERNELS)
|
||||
list(APPEND GROUPED_CONV3D_BWD_WEIGHT
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BF8,
|
||||
F8>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_comp_bf8_f8_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
ConvBwdWeightDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_comp_bf8_f8_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
ConvBwdWeightFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -33,7 +33,9 @@ template <ck::index_t NDimSpatial,
|
||||
typename OutLayout,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType>
|
||||
typename OutDataType,
|
||||
typename ComputeTypeA = InDataType,
|
||||
typename ComputeTypeB = ComputeTypeA>
|
||||
bool profile_grouped_conv_bwd_weight_impl(int do_verification,
|
||||
int init_method,
|
||||
bool do_log,
|
||||
@@ -120,7 +122,9 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp>;
|
||||
OutElementOp,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>;
|
||||
|
||||
// get device op instances
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
|
||||
@@ -20,9 +20,10 @@ enum struct ConvLayout
|
||||
|
||||
enum struct ConvDataType
|
||||
{
|
||||
F32_F32_F32, // 0
|
||||
F16_F16_F16, // 1
|
||||
BF16_F32_BF16, // 2
|
||||
F32_F32_F32, // 0
|
||||
F16_F16_F16, // 1
|
||||
BF16_F32_BF16, // 2
|
||||
F16_F16_F16_BF8_F8 // 3
|
||||
};
|
||||
|
||||
#define OP_NAME "grouped_conv_bwd_weight"
|
||||
@@ -33,7 +34,8 @@ static void print_helper_msg()
|
||||
std::cout << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
|
||||
<< "arg2: data type (0: Input fp32, Weight fp32, Output fp32\n"
|
||||
<< " 1: Input fp16, Weight fp16, Output fp16\n"
|
||||
<< " 2: Input bf16, Weight fp32, Output bf16)\n"
|
||||
<< " 2: Input bf16, Weight fp32, Output bf16\n"
|
||||
<< " 3: Input fp16, Weight fp16, Output fp16, Gemm bf8@fp8)\n"
|
||||
<< "arg3: tensor layout (0: Input[G, N, C, Hi, Wi], Weight[G, K, C, Y, X], Output[G, "
|
||||
"N, K, Ho, Wo]\n"
|
||||
<< " 1: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, "
|
||||
@@ -82,6 +84,12 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
|
||||
using F32 = float;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
#ifdef CK_ENABLE_FP8
|
||||
using F8 = ck::f8_t;
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF8
|
||||
using BF8 = ck::bf8_t;
|
||||
#endif
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
@@ -95,7 +103,9 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
|
||||
auto out_layout,
|
||||
auto in_type,
|
||||
auto wei_type,
|
||||
auto out_type) {
|
||||
auto out_type,
|
||||
auto compute_type_a,
|
||||
auto compute_type_b) {
|
||||
constexpr ck::index_t NDimSpatial = num_dim_spatial_tmp.value;
|
||||
|
||||
using InLayout = decltype(in_layout);
|
||||
@@ -106,13 +116,18 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
|
||||
using WeiDataType = decltype(wei_type);
|
||||
using OutDataType = decltype(out_type);
|
||||
|
||||
using ComputeTypeA = decltype(compute_type_a);
|
||||
using ComputeTypeB = decltype(compute_type_b);
|
||||
|
||||
bool pass = ck::profiler::profile_grouped_conv_bwd_weight_impl<NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType>(
|
||||
OutDataType,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>(
|
||||
do_verification, init_method, do_log, time_kernel, params, split_k);
|
||||
|
||||
return pass ? 0 : 1;
|
||||
@@ -122,80 +137,84 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{});
|
||||
return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return profile(I1, GNWC{}, GKXC{}, GNWK{}, F16{}, F16{}, F16{});
|
||||
return profile(I1, GNWC{}, GKXC{}, GNWK{}, F16{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_F32_BF16)
|
||||
{
|
||||
// fp32 atomic add is used for weight tensor in bf16 kernel
|
||||
return profile(I1, GNWC{}, GKXC{}, GNWK{}, BF16{}, F32{}, BF16{});
|
||||
return profile(I1, GNWC{}, GKXC{}, GNWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{});
|
||||
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F16{}, F16{}, F16{});
|
||||
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F16{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_F32_BF16)
|
||||
{
|
||||
// fp32 atomic add is used for weight tensor in bf16 kernel
|
||||
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, BF16{}, F32{}, BF16{});
|
||||
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{});
|
||||
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{});
|
||||
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_F32_BF16)
|
||||
{
|
||||
// fp32 atomic add is used for weight tensor in bf16 kernel
|
||||
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, F32{}, BF16{});
|
||||
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{});
|
||||
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F16{}, F16{}, F16{});
|
||||
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F16{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_F32_BF16)
|
||||
{
|
||||
// fp32 atomic add is used for weight tensor in bf16 kernel
|
||||
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, F32{}, BF16{});
|
||||
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{});
|
||||
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{});
|
||||
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_F32_BF16)
|
||||
{
|
||||
// fp32 atomic add is used for weight tensor in bf16 kernel
|
||||
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, F32{}, BF16{});
|
||||
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16_BF8_F8)
|
||||
{
|
||||
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, BF8{}, F8{});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user