mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-26 08:00:13 +00:00
Add conv bwd weight fp16 comp bf8 fp8 op, instances and example (#945)
* Add f8 bf8 gemm example * Add element-wise ops * Add intrinsics * Update reference calculation * Add an additional type option for xdlops gemm * Fix build process * Add bf8 to buffer addressing * Update blockwise op, split typeA and typeB * Update for compatibility * Uppdate naming to f8->fp8 * Update naming * Format * Update naming (#937) * Add a client example * Add computetypes to device and gridwise ops * Add instances, update instance factory * Format * Fix a flag * Add ckProfiler mode * Fix typos * Add an example * Add bf8 generator * add bf8 mfma; fixed type_convert for bf8 * move verfication ahead of timing * Update reference calculation * Fix reference * Narrow down float init range * Fix bf8 bf8 mfma * Add bf8 @ fp8 mfma * Update example * Update instances * Update profiler api * Update for compatibility * Format * Remove extra example * Clean up * workaround convert --------- Co-authored-by: Jing Zhang <jizha@amd.com>
This commit is contained in:
@@ -25,6 +25,8 @@ template <ck::index_t NDimSpatial,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation,
|
||||
typename ComputeTypeA = OutDataType,
|
||||
typename ComputeTypeB = InDataType,
|
||||
typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
|
||||
struct ReferenceConvBwdWeight : public device::BaseOperator
|
||||
{
|
||||
@@ -98,8 +100,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
|
||||
if(wi >= 0 &&
|
||||
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3])
|
||||
{
|
||||
float v_out;
|
||||
float v_in;
|
||||
ComputeTypeA v_out;
|
||||
ComputeTypeB v_in;
|
||||
|
||||
arg.out_element_op_(
|
||||
v_out, ck::type_convert<float>(arg.output_(g, n, k, wo)));
|
||||
@@ -107,7 +109,7 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
|
||||
arg.in_element_op_(
|
||||
v_in, ck::type_convert<float>(arg.input_(g, n, c, wi)));
|
||||
|
||||
v_acc += v_out * v_in;
|
||||
v_acc += type_convert<float>(v_out) * type_convert<float>(v_in);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -158,8 +160,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
|
||||
wi >= 0 &&
|
||||
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[4])
|
||||
{
|
||||
float v_out;
|
||||
float v_in;
|
||||
ComputeTypeA v_out;
|
||||
ComputeTypeB v_in;
|
||||
|
||||
arg.out_element_op_(
|
||||
v_out,
|
||||
@@ -168,7 +170,7 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
|
||||
arg.in_element_op_(
|
||||
v_in, ck::type_convert<float>(arg.input_(g, n, c, hi, wi)));
|
||||
|
||||
v_acc += v_out * v_in;
|
||||
v_acc += type_convert<float>(v_out) * type_convert<float>(v_in);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -226,8 +228,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
|
||||
ck::type_convert<std::size_t>(wi) <
|
||||
arg.input_.GetLengths()[5])
|
||||
{
|
||||
float v_out;
|
||||
float v_in;
|
||||
ComputeTypeA v_out;
|
||||
ComputeTypeB v_in;
|
||||
|
||||
arg.out_element_op_(v_out,
|
||||
ck::type_convert<float>(
|
||||
@@ -237,7 +239,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
|
||||
ck::type_convert<float>(
|
||||
arg.input_(g, n, c, di, hi, wi)));
|
||||
|
||||
v_acc += v_out * v_in;
|
||||
v_acc +=
|
||||
type_convert<float>(v_out) * type_convert<float>(v_in);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,6 +19,14 @@ using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
#ifdef CK_ENABLE_FP8
|
||||
using F8 = ck::f8_t;
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_BF8
|
||||
using BF8 = ck::bf8_t;
|
||||
#endif
|
||||
|
||||
using Empty_Tuple = ck::Tuple<>;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
@@ -133,6 +141,43 @@ using device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances = std::tuple<
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardWeightSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_comp_bf8_f8_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| Compute| Compute|
|
||||
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| TypeA| TypeB|
|
||||
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| | |
|
||||
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | |
|
||||
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
// generic instance
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2, BF8, F8>,
|
||||
// instance for small conv.K
|
||||
// for fp16 conv.K and conv.C must be divisible by 2
|
||||
// since half_t atomic_add require scalar_per_x_vector % 2 == 0
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2, BF8, F8>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>,
|
||||
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>
|
||||
#endif
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -216,6 +216,21 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BF8,
|
||||
F8>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef DL_KERNELS
|
||||
// dl
|
||||
@@ -464,7 +479,9 @@ template <ck::index_t NumDimSpatial,
|
||||
typename OutLayout,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType>
|
||||
typename OutDataType,
|
||||
typename ComputeTypeA,
|
||||
typename ComputeTypeB>
|
||||
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvBwdWeight<
|
||||
NumDimSpatial,
|
||||
InLayout,
|
||||
@@ -475,7 +492,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
OutDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>>
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>>
|
||||
{
|
||||
using DeviceOp = DeviceGroupedConvBwdWeight<NumDimSpatial,
|
||||
InLayout,
|
||||
@@ -486,7 +505,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
OutDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>;
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
@@ -706,7 +727,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
is_same_v<OutDataType, half_t> &&
|
||||
is_same_v<ComputeTypeA, half_t> &&
|
||||
is_same_v<ComputeTypeB, half_t>)
|
||||
{
|
||||
#ifdef DL_KERNELS
|
||||
add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
@@ -728,6 +751,15 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t> &&
|
||||
is_same_v<ComputeTypeA, bf8_t> && is_same_v<ComputeTypeB, f8_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
@@ -111,6 +111,22 @@ struct GeneratorTensor_2<ck::f8_t>
|
||||
};
|
||||
#endif
|
||||
|
||||
#if defined CK_ENABLE_BF8
|
||||
template <>
|
||||
struct GeneratorTensor_2<ck::bf8_t>
|
||||
{
|
||||
int min_value = 0;
|
||||
int max_value = 1;
|
||||
|
||||
template <typename... Is>
|
||||
ck::bf8_t operator()(Is...)
|
||||
{
|
||||
float tmp = (std::rand() % (max_value - min_value)) + min_value;
|
||||
return ck::type_convert<ck::bf8_t>(tmp);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
struct GeneratorTensor_3
|
||||
{
|
||||
@@ -162,6 +178,25 @@ struct GeneratorTensor_3<ck::f8_t>
|
||||
};
|
||||
#endif
|
||||
|
||||
#if defined CK_ENABLE_BF8
|
||||
template <>
|
||||
struct GeneratorTensor_3<ck::bf8_t>
|
||||
{
|
||||
float min_value = 0;
|
||||
float max_value = 1;
|
||||
|
||||
template <typename... Is>
|
||||
ck::bf8_t operator()(Is...)
|
||||
{
|
||||
float tmp = float(std::rand()) / float(RAND_MAX);
|
||||
|
||||
float fp32_tmp = min_value + tmp * (max_value - min_value);
|
||||
|
||||
return ck::type_convert<ck::bf8_t>(fp32_tmp);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
struct GeneratorTensor_4
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user