Add conv bwd weight fp16 comp bf8 fp8 op, instances and example (#945)

* Add f8 bf8 gemm example

* Add element-wise ops

* Add intrinsics

* Update reference calculation

* Add an additional type option for xdlops gemm

* Fix build process

* Add bf8 to buffer addressing

* Update blockwise op, split typeA and typeB

* Update for compatibility

* Update naming to f8->fp8

* Update naming

* Format

* Update naming (#937)

* Add a client example

* Add computetypes to device and gridwise ops

* Add instances, update instance factory

* Format

* Fix a flag

* Add ckProfiler mode

* Fix typos

* Add an example

* Add bf8 generator

* add bf8 mfma; fixed type_convert for bf8

* move verification ahead of timing

* Update reference calculation

* Fix reference

* Narrow down float init range

* Fix bf8 bf8 mfma

* Add bf8 @ fp8 mfma

* Update example

* Update instances

* Update profiler api

* Update for compatibility

* Format

* Remove extra example

* Clean up

* workaround convert

---------

Co-authored-by: Jing Zhang <jizha@amd.com>
This commit is contained in:
Rostyslav Geyyer
2023-10-04 08:19:08 -05:00
committed by GitHub
parent e921e1f08d
commit 42facfc6b7
22 changed files with 696 additions and 106 deletions

View File

@@ -48,7 +48,8 @@ struct ComputePtrOffsetOfStridedBatch
} // namespace
template <typename GridwiseGemm,
typename FloatAB,
typename FloatA,
typename FloatB,
typename FloatC,
typename AElementwiseOperation,
typename BElementwiseOperation,
@@ -64,8 +65,8 @@ __global__ void
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_xdlops_bwd_weight(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
const FloatA* __restrict__ p_a_grid,
const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
@@ -91,7 +92,7 @@ __global__ void
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
__shared__ FloatAB p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB)];
__shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
@@ -163,7 +164,9 @@ template <ck::index_t NDimSpatial,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
typename ComputeTypeA = InDataType,
typename ComputeTypeB = ComputeTypeA>
struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
: public DeviceGroupedConvBwdWeight<NDimSpatial,
InLayout,
@@ -174,7 +177,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
OutDataType,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
OutElementwiseOperation,
ComputeTypeA,
ComputeTypeB>
{
using DeviceOp = DeviceGroupedConvBwdWeight_Xdl_CShuffle;
@@ -1045,7 +1050,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight<
BlockSize,
ADataType, // TODO: distinguish A/B datatype
ADataType,
BDataType,
AccDataType,
CDataType,
InMemoryDataOperationEnum::AtomicAdd,
@@ -1090,7 +1096,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
CBlockTransferScalarPerVector_NWaveNPerXdl,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
true,
true>;
true,
1,
PipelineVersion::v1,
ComputeTypeA,
ComputeTypeB>;
// Argument
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
@@ -1217,8 +1227,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
index_t M01_;
index_t N01_;
InElementwiseOperation a_element_op_;
OutElementwiseOperation b_element_op_;
OutElementwiseOperation a_element_op_;
InElementwiseOperation b_element_op_;
WeiElementwiseOperation c_element_op_;
// for checking IsSupportedArgument()
@@ -1281,7 +1291,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const auto kernel = kernel_batched_gemm_xdlops_bwd_weight<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
ADataType,
BDataType,
CDataType,
OutElementwiseOperation,
InElementwiseOperation,