mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 02:27:57 +00:00
Grouped conv bwd data with fp16 input and bf8fp8 comp (#962)
* Add f8 bf8 gemm example
* Add element-wise ops
* Add intrinsics
* Update reference calculation
* Add an additional type option for xdlops gemm
* Fix build process
* Add bf8 to buffer addressing
* Update blockwise op, split typeA and typeB
* Update for compatibility
* Uppdate naming to f8->fp8
* Update naming
* Format
* Update naming (#937)
* Add a client example
* Add computetypes to device and gridwise ops
* Add instances, update instance factory
* Format
* Fix a flag
* Add ckProfiler mode
* Fix typos
* Add an example
* Add bf8 generator
* add bf8 mfma; fixed type_convert for bf8
* move verfication ahead of timing
* Update reference calculation
* Fix reference
* Narrow down float init range
* Fix bf8 bf8 mfma
* Add bf8 @ fp8 mfma
* Update example
* Update instances
* Update profiler api
* Update for compatibility
* Format
* Remove extra example
* Clean up
* workaround convert
* added instance of f16_bf8f8, and client example
* fixed mfma selector
* format
---------
Co-authored-by: Rostyslav Geyyer <rosty.geyyer@amd.com>
Co-authored-by: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com>
Co-authored-by: Jing Zhang <jizha@amd.com>
[ROCm/composable_kernel commit: 04f93aadb8]
This commit is contained in:
@@ -29,7 +29,9 @@ template <ck::index_t NDimSpatial,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation>
|
||||
typename CDEElementwiseOperation,
|
||||
typename AComputeType = ADataType,
|
||||
typename BComputeType = AComputeType>
|
||||
struct DeviceGroupedConvBwdDataMultipleD : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
@@ -198,7 +198,9 @@ template <index_t NDimSpatial,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler()>
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler(),
|
||||
typename AComputeType = ADataType,
|
||||
typename BComputeType = AComputeType>
|
||||
struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
: public DeviceGroupedConvBwdDataMultipleD<NDimSpatial,
|
||||
ALayout, // output image
|
||||
@@ -211,7 +213,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
EDataType, // input image
|
||||
AElementwiseOp,
|
||||
BElementwiseOp,
|
||||
CDEElementwiseOp>
|
||||
CDEElementwiseOp,
|
||||
AComputeType,
|
||||
BComputeType>
|
||||
{
|
||||
// TODO: Extend support for more spatial dimensions.
|
||||
static_assert(NDimSpatial == 2 || NDimSpatial == 3,
|
||||
@@ -312,9 +316,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
|
||||
ABDataType, // TODO: distinguish A/B datatype
|
||||
ABDataType, // TODO: distinguish A/B datatype
|
||||
ABDataType, // TODO: distinguish A/B datatype
|
||||
ABDataType,
|
||||
ABDataType,
|
||||
AComputeType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
@@ -354,7 +358,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched>;
|
||||
LoopSched,
|
||||
PipelineVersion::v1,
|
||||
BComputeType>;
|
||||
|
||||
template <typename Desc_K0_M_K1>
|
||||
static auto transform_k0_m_k1_to_m_k(const Desc_K0_M_K1& desc_k0_m_k1)
|
||||
|
||||
@@ -31,7 +31,7 @@ namespace ck {
|
||||
// D0, D1, ... and E have the same layout
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename ComputeDataType_,
|
||||
typename AComputeDataType_,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename DsDataType,
|
||||
@@ -72,7 +72,8 @@ template <typename ADataType,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopScheduler LoopSched,
|
||||
PipelineVersion PipelineVer = PipelineVersion::v1>
|
||||
PipelineVersion PipelineVer = PipelineVersion::v1,
|
||||
typename BComputeDataType = AComputeDataType_>
|
||||
struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
@@ -100,10 +101,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
|
||||
|
||||
#if CK_WORKAROUND_DENORM_FIX
|
||||
using ComputeDataType =
|
||||
conditional_t<is_same_v<ComputeDataType_, ck::half_t>, ck::bhalf_t, ComputeDataType_>;
|
||||
using AComputeDataType =
|
||||
conditional_t<is_same_v<AComputeDataType_, ck::half_t>, ck::bhalf_t, AComputeDataType_>;
|
||||
#else
|
||||
using ComputeDataType = ComputeDataType_;
|
||||
using AComputeDataType = AComputeDataType_;
|
||||
#endif
|
||||
|
||||
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
|
||||
@@ -172,8 +173,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
constexpr auto c_block_size =
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
|
||||
|
||||
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
|
||||
sizeof(ComputeDataType),
|
||||
return math::max(a_block_space_size_aligned * sizeof(AComputeDataType) +
|
||||
b_block_space_size_aligned * sizeof(BComputeDataType),
|
||||
c_block_size * sizeof(CShuffleDataType));
|
||||
}
|
||||
|
||||
@@ -502,7 +503,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ADataType,
|
||||
ComputeDataType,
|
||||
AComputeDataType,
|
||||
decltype(a_grid_desc_ak0_m_ak1),
|
||||
decltype(a_block_desc_ak0_m_ak1),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
@@ -533,7 +534,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
BComputeDataType,
|
||||
decltype(b_grid_desc_bk0_n_bk1),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
@@ -561,14 +562,15 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
|
||||
// register
|
||||
// sanity check
|
||||
constexpr index_t KPack =
|
||||
math::max(math::lcm(AK1, BK1),
|
||||
MfmaSelector<ComputeDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
|
||||
constexpr index_t KPack = math::max(
|
||||
math::lcm(AK1, BK1),
|
||||
MfmaSelector<AComputeDataType, MPerXdl, NPerXdl, BComputeDataType>::selected_mfma
|
||||
.k_per_blk);
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
ComputeDataType,
|
||||
ComputeDataType,
|
||||
AComputeDataType,
|
||||
BComputeDataType,
|
||||
AccDataType,
|
||||
decltype(a_block_desc_ak0_m_ak1),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
@@ -586,10 +588,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<ComputeDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
static_cast<AComputeDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
|
||||
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<ComputeDataType*>(p_shared) + a_block_space_size_aligned,
|
||||
static_cast<BComputeDataType*>(p_shared) + a_block_space_size_aligned,
|
||||
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
|
||||
|
||||
Reference in New Issue
Block a user