Introduce MX GEMM for FP8 data type (#2000)

Andriy Roshchenko
2025-03-24 15:41:07 -06:00
committed by GitHub
parent c027637a8f
commit 6660dc6b8e
11 changed files with 4129 additions and 135 deletions

@@ -189,15 +189,36 @@ struct ThreadwiseTensorSliceTransfer_v1r3
const ElementwiseOperation element_op_;
}; // struct ThreadwiseTensorSliceTransfer_v1r3
// Assume:
// 1. src:
// 1. SrcDesc is not known at compile-time
// 2. SrcBuffer is DynamicBuffer
// 3. src_slice_origin_idx is not known at compile-time
// 2. dst:
// 1. DstDesc is known at compile-time
// 2. DstBuffer is StaticBuffer
// 3. dst_slice_origin_idx is known at compile-time
/**
* @brief Helper structure that facilitates transfer of source (grid) data to destination threads.
*
* @details The following assumptions are made:
* - For Source (Grid) Data:
* 1. The source tensor descriptor SrcDesc is not known at compile-time.
* 2. The source buffer SrcBuffer is a DynamicBuffer.
* 3. The source slice origin index src_slice_origin_idx is not known at compile-time.
* - For Destination (Thread) Data:
* 1. The destination tensor descriptor DstDesc is known at compile-time.
* 2. The destination buffer DstBuffer is a StaticBuffer.
* 3. The destination slice origin index dst_slice_origin_idx is known at compile-time.
*
* @tparam SrcData The data type of the source tensor.
* @tparam DstData The data type of the destination tensor.
* @tparam SrcDesc The descriptor type of the source tensor.
* @tparam DstDesc The descriptor type of the destination tensor.
* @tparam SliceLengths The lengths of the slice to be transferred.
* @tparam DimAccessOrder The order of dimension access for the space-filling curve.
* @tparam SrcVectorDim The dimension along which vectorized access is performed in the source
* tensor.
* @tparam SrcScalarPerVector The number of scalar elements per vector in the source tensor.
* @tparam SrcScalarStrideInVector The stride of scalar elements within a vector in the source
* tensor.
* @tparam SrcResetCoordinateAfterRun Controls whether the source coordinate is restored after
* each Run or rolled back one step in MoveSrcSliceWindow.
* @tparam InvalidElementAsNaN Whether to fill invalid elements with NaN (only applicable for
* floating-point types).
*
*/
template <typename SrcData,
typename DstData,
typename SrcDesc,