mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 14:54:47 +00:00
Add fp8 @ bf8 gemm support and example (#933)
* Add f8 bf8 gemm example
* Add element-wise ops
* Add intrinsics
* Update reference calculation
* Add an additional type option for xdlops gemm
* Fix build process
* Add bf8 to buffer addressing
* Update blockwise op, split typeA and typeB
* Update for compatibility
* Uppdate naming to f8->fp8
* Update naming
* Format
[ROCm/composable_kernel commit: bd09b5c538]
This commit is contained in:
@@ -21,7 +21,8 @@ template <typename ADataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename ComputType = ADataType>
|
||||
typename ComputeTypeA = ADataType,
|
||||
typename ComputeTypeB = ComputeTypeA>
|
||||
struct ReferenceGemm : public device::BaseOperator
|
||||
{
|
||||
// Argument
|
||||
@@ -65,8 +66,8 @@ struct ReferenceGemm : public device::BaseOperator
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
ComputType v_a;
|
||||
ComputType v_b;
|
||||
ComputeTypeA v_a;
|
||||
ComputeTypeB v_b;
|
||||
|
||||
// use PassThrough instead of ConvertBF16RTN for reference calculation
|
||||
if constexpr(is_same_v<AElementwiseOperation,
|
||||
|
||||
@@ -95,7 +95,7 @@ struct GeneratorTensor_2<int8_t>
|
||||
}
|
||||
};
|
||||
|
||||
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
|
||||
#if defined CK_ENABLE_FP8
|
||||
template <>
|
||||
struct GeneratorTensor_2<ck::f8_t>
|
||||
{
|
||||
@@ -143,7 +143,7 @@ struct GeneratorTensor_3<ck::bhalf_t>
|
||||
}
|
||||
};
|
||||
|
||||
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
|
||||
#if defined CK_ENABLE_FP8
|
||||
template <>
|
||||
struct GeneratorTensor_3<ck::f8_t>
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user