mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 11:30:02 +00:00
Grouped conv bwd data with fp16 input and bf8fp8 comp (#962)
* Add f8 bf8 gemm example
* Add element-wise ops
* Add intrinsics
* Update reference calculation
* Add an additional type option for xdlops gemm
* Fix build process
* Add bf8 to buffer addressing
* Update blockwise op, split typeA and typeB
* Update for compatibility
* Update naming to f8->fp8
* Update naming
* Format
* Update naming (#937)
* Add a client example
* Add computetypes to device and gridwise ops
* Add instances, update instance factory
* Format
* Fix a flag
* Add ckProfiler mode
* Fix typos
* Add an example
* Add bf8 generator
* add bf8 mfma; fixed type_convert for bf8
* move verification ahead of timing
* Update reference calculation
* Fix reference
* Narrow down float init range
* Fix bf8 bf8 mfma
* Add bf8 @ fp8 mfma
* Update example
* Update instances
* Update profiler api
* Update for compatibility
* Format
* Remove extra example
* Clean up
* workaround convert
* added instance of f16_bf8f8, and client example
* fixed mfma selector
* format
---------
Co-authored-by: Rostyslav Geyyer <rosty.geyyer@amd.com>
Co-authored-by: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com>
Co-authored-by: Jing Zhang <jizha@amd.com>
[ROCm/composable_kernel commit: 04f93aadb8]
This commit is contained in:
@@ -198,7 +198,9 @@ template <index_t NDimSpatial,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler()>
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler(),
|
||||
typename AComputeType = ADataType,
|
||||
typename BComputeType = AComputeType>
|
||||
struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
: public DeviceGroupedConvBwdDataMultipleD<NDimSpatial,
|
||||
ALayout, // output image
|
||||
@@ -211,7 +213,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
EDataType, // input image
|
||||
AElementwiseOp,
|
||||
BElementwiseOp,
|
||||
CDEElementwiseOp>
|
||||
CDEElementwiseOp,
|
||||
AComputeType,
|
||||
BComputeType>
|
||||
{
|
||||
// TODO: Extend support for more spatial dimensions.
|
||||
static_assert(NDimSpatial == 2 || NDimSpatial == 3,
|
||||
@@ -312,9 +316,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
|
||||
ABDataType, // TODO: distinguish A/B datatype
|
||||
ABDataType, // TODO: distinguish A/B datatype
|
||||
ABDataType, // TODO: distinguish A/B datatype
|
||||
ABDataType,
|
||||
ABDataType,
|
||||
AComputeType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
@@ -354,7 +358,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched>;
|
||||
LoopSched,
|
||||
PipelineVersion::v1,
|
||||
BComputeType>;
|
||||
|
||||
template <typename Desc_K0_M_K1>
|
||||
static auto transform_k0_m_k1_to_m_k(const Desc_K0_M_K1& desc_k0_m_k1)
|
||||
|
||||
Reference in New Issue
Block a user