mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
3d grouped conv fwd with input/output fp16 and comp fp8 (#931)
* add f8 comp instance * fixed * fixed comments * rename * fixed dtype * format * fixed CI * fixed ci * add missing ComputeType * fixed cit * fixed * Update cmake-ck-dev.sh --------- Co-authored-by: Jing Zhang <jizha@amd.com>
This commit is contained in:
@@ -29,7 +29,8 @@ template <index_t NDimSpatial,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation>
|
||||
typename CDEElementwiseOperation,
|
||||
typename ComputeType = ADataType>
|
||||
struct DeviceGroupedConvFwdMultipleD : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
@@ -211,7 +211,8 @@ template <index_t NDimSpatial,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler()>
|
||||
typename ComputeDataType = ADataType,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler()>
|
||||
struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
|
||||
: public DeviceGroupedConvFwdMultipleD<NDimSpatial,
|
||||
ALayout,
|
||||
@@ -224,7 +225,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>
|
||||
CDEElementwiseOperation,
|
||||
ComputeDataType>
|
||||
{
|
||||
using DeviceOp = DeviceGroupedConvFwdMultipleD_Xdl_CShuffle;
|
||||
|
||||
@@ -323,8 +325,6 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
|
||||
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
|
||||
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;
|
||||
|
||||
using ComputeDataType = ADataType;
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
|
||||
Reference in New Issue
Block a user