diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index e0124e57f7..703d0e3834 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -66,21 +66,17 @@ endif() add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp) -if(GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR GPU_TARGETS MATCHES "gfx942") - add_example_executable(example_gemm_xdl_fp8 gemm_xdl_fp8.cpp) - if(result EQUAL 0) +add_example_executable(example_gemm_xdl_fp8 gemm_xdl_fp8.cpp) +if(result EQUAL 0) add_dependencies(example_gemm_xdl example_gemm_xdl_fp8) - endif() endif() -if(GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR GPU_TARGETS MATCHES "gfx942") - add_example_executable(example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp) - if(result EQUAL 0) +add_example_executable(example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp) +if(result EQUAL 0) add_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8) - endif() endif() add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp) if(result EQUAL 0) - add_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8) + add_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8) endif() diff --git a/example/01_gemm/gemm_xdl_fp8.cpp b/example/01_gemm/gemm_xdl_fp8.cpp index 1015926777..2d4df3fc13 100644 --- a/example/01_gemm/gemm_xdl_fp8.cpp +++ b/example/01_gemm/gemm_xdl_fp8.cpp @@ -7,9 +7,9 @@ using ADataType = ck::f8_t; using BDataType = ck::f8_t; -using CDataType = ck::f8_t; +using CDataType = ck::half_t; using AccDataType = float; -using CShuffleDataType = ck::f8_t; +using CShuffleDataType = float; using ALayout = Row; using BLayout = Col; @@ -27,7 +27,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle // ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 16>; + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: diff --git a/example/01_gemm/gemm_xdl_fp8_bf8.cpp b/example/01_gemm/gemm_xdl_fp8_bf8.cpp index 0d69b7a90f..b54df8ff3d 100644 --- a/example/01_gemm/gemm_xdl_fp8_bf8.cpp +++ b/example/01_gemm/gemm_xdl_fp8_bf8.cpp @@ -7,9 +7,9 @@ using ADataType = ck::f8_t; using BDataType = ck::bf8_t; -using CDataType = ck::f8_t; +using CDataType = ck::half_t; using AccDataType = float; -using CShuffleDataType = ck::f8_t; +using CShuffleDataType = float; using ALayout = Row; using BLayout = Col; @@ -31,7 +31,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle // ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>; + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm __host__ __device__ void operator()(bf8_t& y, const half_t& x) const { - // to-do: fix half_t to bf8_t convert - y = ck::type_convert(ck::type_convert(x)); + y = ck::type_convert(x); } #endif }; diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 505de73ae2..8b70a6bfb4 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -344,7 +344,7 @@ inline __host__ __device__ f8_t f8_convert_sr(half_t x) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) // convert to float and use native converion return f8_convert_sr(type_convert(x)); -#else +#elif 0 constexpr bool negative_zero_nan = true; constexpr bool clip = true; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; @@ -353,6 +353,8 @@ inline __host__ __device__ f8_t f8_convert_sr(half_t x) return utils:: cast_to_f8( x, rng); +#else + return type_convert(type_convert(x)); #endif } #endif @@ -393,7 +395,7 @@ inline __host__ __device__ bf8_t f8_convert_sr(half_t x) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) // convert to float and use native converion return f8_convert_sr(type_convert(x)); -#else +#elif 0 constexpr bool negative_zero_nan = true; constexpr bool clip = true; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; @@ -403,6 +405,8 @@ inline __host__ __device__ bf8_t f8_convert_sr(half_t x) return utils:: cast_to_f8( x, rng); +#else + return type_convert(type_convert(x)); #endif } #endif