mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 12:11:19 +00:00
Introduce MX GEMM for FP8 data type (#2000)
This commit is contained in:
committed by
GitHub
parent
c027637a8f
commit
6660dc6b8e
@@ -189,15 +189,36 @@ struct ThreadwiseTensorSliceTransfer_v1r3
|
||||
const ElementwiseOperation element_op_;
|
||||
}; // namespace ThreadwiseTensorSliceTransfer_v1r3
|
||||
|
||||
// Assume:
|
||||
// 1. src:
|
||||
// 1. SrcDesc is not known at compile-time
|
||||
// 2. SrcBuffer is DynamicBuffer
|
||||
// 3. src_slice_origin_idx is not known at compile-time
|
||||
// 2. dst:
|
||||
// 1. DstDesc is known at compile-time
|
||||
// 2. DstBuffer is StaticBuffer
|
||||
// 3. dst_slice_origin_idx is known at compile-time
|
||||
/**
|
||||
* @brief Helper structure that facilitates transfer of source (grid) data to destination threads.
|
||||
*
|
||||
* @details The following assumptions are made:
|
||||
* - For Source (Grid) Data:
|
||||
* 1. The source tensor descriptor SrcDesc is not known at compile-time.
|
||||
* 2. The source buffer is a dynamic buffer.
|
||||
* 3. The source slice origin index src_slice_origin_idx is not known at compile-time.
|
||||
* - For Destination (Thread) Data:
|
||||
* 1. The destination tensor descriptor DstDesc is known at compile-time.
|
||||
* 2. The destination buffer dst_buf is a static buffer.
|
||||
* 3. The destination slice origin index dst_slice_origin_idx is known at compile-time.
|
||||
*
|
||||
* @tparam SrcData The data type of the source tensor.
|
||||
* @tparam DstData The data type of the destination tensor.
|
||||
* @tparam SrcDesc The descriptor type of the source tensor.
|
||||
* @tparam DstDesc The descriptor type of the destination tensor.
|
||||
* @tparam SliceLengths The lengths of the slice to be transferred.
|
||||
* @tparam DimAccessOrder The order of dimension access for the space-filling curve.
|
||||
* @tparam SrcVectorDim The dimension along which vectorized access is performed in the source
|
||||
* tensor.
|
||||
* @tparam SrcScalarPerVector The number of scalar elements per vector in the source tensor.
|
||||
* @tparam SrcScalarStrideInVector The stride of scalar elements within a vector in the source
|
||||
* tensor.
|
||||
* @tparam SrcResetCoordinateAfterRun controls whether source coordinate is restored after each Run
|
||||
* or rolled back one step in MoveSrcSliceWindow
|
||||
* @tparam InvalidElementAsNaN Whether to fill invalid elements with NaN (only applicable for
|
||||
* floating-point types).
|
||||
*
|
||||
*/
|
||||
template <typename SrcData,
|
||||
typename DstData,
|
||||
typename SrcDesc,
|
||||
|
||||
Reference in New Issue
Block a user