mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Grouped conv bwd data with fp16 input and bf8fp8 comp (#962)
* Add f8 bf8 gemm example * Add element-wise ops * Add intrinsics * Update reference calculation * Add an additional type option for xdlops gemm * Fix build process * Add bf8 to buffer addressing * Update blockwise op, split typeA and typeB * Update for compatibility * Uppdate naming to f8->fp8 * Update naming * Format * Update naming (#937) * Add a client example * Add computetypes to device and gridwise ops * Add instances, update instance factory * Format * Fix a flag * Add ckProfiler mode * Fix typos * Add an example * Add bf8 generator * add bf8 mfma; fixed type_convert for bf8 * move verfication ahead of timing * Update reference calculation * Fix reference * Narrow down float init range * Fix bf8 bf8 mfma * Add bf8 @ fp8 mfma * Update example * Update instances * Update profiler api * Update for compatibility * Format * Remove extra example * Clean up * workaround convert * added instance of f16_bf8f8, and client example * fixed mfma selector * format --------- Co-authored-by: Rostyslav Geyyer <rosty.geyyer@amd.com> Co-authored-by: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Co-authored-by: Jing Zhang <jizha@amd.com>
This commit is contained in:
@@ -198,7 +198,9 @@ template <index_t NDimSpatial,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler()>
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler(),
|
||||
typename AComputeType = ADataType,
|
||||
typename BComputeType = AComputeType>
|
||||
struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
: public DeviceGroupedConvBwdDataMultipleD<NDimSpatial,
|
||||
ALayout, // output image
|
||||
@@ -211,7 +213,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
EDataType, // input image
|
||||
AElementwiseOp,
|
||||
BElementwiseOp,
|
||||
CDEElementwiseOp>
|
||||
CDEElementwiseOp,
|
||||
AComputeType,
|
||||
BComputeType>
|
||||
{
|
||||
// TODO: Extend support for more spatial dimensions.
|
||||
static_assert(NDimSpatial == 2 || NDimSpatial == 3,
|
||||
@@ -312,9 +316,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
|
||||
ABDataType, // TODO: distinguish A/B datatype
|
||||
ABDataType, // TODO: distinguish A/B datatype
|
||||
ABDataType, // TODO: distinguish A/B datatype
|
||||
ABDataType,
|
||||
ABDataType,
|
||||
AComputeType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
@@ -354,7 +358,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched>;
|
||||
LoopSched,
|
||||
PipelineVersion::v1,
|
||||
BComputeType>;
|
||||
|
||||
template <typename Desc_K0_M_K1>
|
||||
static auto transform_k0_m_k1_to_m_k(const Desc_K0_M_K1& desc_k0_m_k1)
|
||||
|
||||
Reference in New Issue
Block a user