mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
[CK] Fix misc issues in CK examples (#2890)
* [CK] Fix misc CK issues
* revert fp8 change; it causes CI failures.
* resubmit fp8 change
[ROCm/composable_kernel commit: f076f207ce]
This commit is contained in:
@@ -36,7 +36,7 @@ using BDataType = ck::half_t;
|
||||
using CDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
#else
|
||||
< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 4, 4, 7, 1>;
|
||||
< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 16, 128, 4, 4, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 4, 4, 7, 1>;
|
||||
using ADataType = float;
|
||||
using BDataType = float;
|
||||
using CDataType = float;
|
||||
@@ -185,7 +185,6 @@ int main(int argc, char* argv[])
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
|
||||
// do GEMM
|
||||
auto gemm = DeviceGemmInstance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
@@ -209,8 +208,7 @@ int main(int argc, char* argv[])
|
||||
return 0;
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
|
||||
|
||||
@@ -125,7 +125,7 @@ inline bool parse_cmd_args(int argc,
|
||||
|
||||
const ck::index_t num_dim_spatial = std::stoi(argv[4]);
|
||||
problem_size = ck::utils::conv::parse_conv_param(
|
||||
num_dim_spatial, threshold_to_catch_partial_args, argv);
|
||||
num_dim_spatial, threshold_to_catch_partial_args + 1, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -23,7 +23,7 @@ using RsGlobalReduceOp =
|
||||
static constexpr auto ConvSpec =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
// clang-format off
|
||||
template <ck::index_t NDimSpatial>
|
||||
|
||||
@@ -65,7 +65,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDLayern
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ThreadClusterLengths| ScalarPerVector| ThreadClusterLengths| ThreadSliceSize|
|
||||
//######| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _M_N| _M_N| _M|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, HLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EMeanVarDataType, GammaDataType, BetaDataType, HDataType, AElementOp, BElementOp, CDEElementOp, HElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<8, 32>, 8>;
|
||||
< ALayout, BLayout, DsLayout, HLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EMeanVarDataType, GammaDataType, BetaDataType, HDataType, AElementOp, BElementOp, CDEElementOp, HElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 4, S<8, 32>, 4>;
|
||||
// clang-format on
|
||||
|
||||
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
|
||||
@@ -154,8 +154,8 @@ void host_gemm_layernorm(Tensor<HDataType>& h_m_n,
|
||||
|
||||
int main()
|
||||
{
|
||||
// temp disable on gfx11 & gfx12
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
// temp disable on gfx11
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -62,29 +62,32 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
const Block2ETileMap block_2_etile_map,
|
||||
index_t NRaw)
|
||||
{
|
||||
#if defined(__gfx9__)
|
||||
__shared__ char p_shared[GridwiseGemmWelford::GetSharedMemoryNumberOfByte()];
|
||||
#if defined(__gfx9__) || defined(__gfx12__)
|
||||
if constexpr(GridwiseGemmWelford::template IsValidCompilationParameter<>())
|
||||
{
|
||||
__shared__ char p_shared[GridwiseGemmWelford::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemmWelford::template Run<HasMainKBlockLoop>(
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_welford_mean_grid,
|
||||
p_welford_var_grid,
|
||||
p_welford_count_grid,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
mean_var_grid_desc_mblock_mperblock_nblock,
|
||||
count_grid_desc_mblock_mperblock_nblock,
|
||||
block_2_etile_map,
|
||||
NRaw);
|
||||
GridwiseGemmWelford::template Run<HasMainKBlockLoop>(
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_welford_mean_grid,
|
||||
p_welford_var_grid,
|
||||
p_welford_count_grid,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
mean_var_grid_desc_mblock_mperblock_nblock,
|
||||
count_grid_desc_mblock_mperblock_nblock,
|
||||
block_2_etile_map,
|
||||
NRaw);
|
||||
}
|
||||
#else
|
||||
ignore = p_a_grid;
|
||||
ignore = p_b_grid;
|
||||
|
||||
@@ -321,12 +321,6 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
|
||||
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(arg.c_grid_desc_m_n_);
|
||||
|
||||
auto b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3 =
|
||||
GridwiseGemm::MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(arg.b_grid_desc_k0_n_k1_);
|
||||
|
||||
if(has_main_k0_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_skip_b_lds_v1<
|
||||
@@ -335,8 +329,7 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
|
||||
CDataType,
|
||||
remove_reference_t<DeviceGemmXdlSkipBLds::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceGemmXdlSkipBLds::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<typename GridwiseGemm::BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3>,
|
||||
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
|
||||
remove_reference_t<DeviceGemmXdlSkipBLds::CGridDesc_M_N>,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
@@ -352,8 +345,8 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_,
|
||||
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
|
||||
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
@@ -367,8 +360,7 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
|
||||
CDataType,
|
||||
remove_reference_t<DeviceGemmXdlSkipBLds::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceGemmXdlSkipBLds::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<typename GridwiseGemm::BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3>,
|
||||
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
|
||||
remove_reference_t<DeviceGemmXdlSkipBLds::CGridDesc_M_N>,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
@@ -384,8 +376,8 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_,
|
||||
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
|
||||
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
|
||||
@@ -369,11 +369,10 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
|
||||
GemmSpec == GemmSpecialization::MNKPadding)
|
||||
{
|
||||
// pad M
|
||||
return transform_tensor_descriptor(
|
||||
descriptor,
|
||||
make_tuple(make_right_pad_transform(descriptor, MPad)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
return transform_tensor_descriptor(descriptor,
|
||||
make_tuple(make_right_pad_transform(MRaw, MPad)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -616,7 +615,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
|
||||
using RDataType = remove_cvref_t<tuple_element_t<i.value, RsDataType>>;
|
||||
|
||||
// R pointer
|
||||
p_rs_grid_(i) = static_cast<RDataType*>(p_rs[i]);
|
||||
p_rs_grid_(i) = static_cast<RDataType*>(p_rs[i]);
|
||||
compute_ptr_offset_of_batch_.BatchStrideRs_(i) = r_g_n_wos_strides[0];
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -21,8 +21,7 @@ template <typename GridwiseGemm,
|
||||
typename FloatC,
|
||||
typename AGridDesc_K0_M_K1,
|
||||
typename BGridDesc_K0_N_K1,
|
||||
typename BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3,
|
||||
typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
|
||||
typename CGridDesc_M_N,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
@@ -32,17 +31,16 @@ __global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_gemm_xdlops_skip_b_lds_v1(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
|
||||
const BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
|
||||
const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CElementwiseOperation c_element_op,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
kernel_gemm_xdlops_skip_b_lds_v1(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
|
||||
const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
|
||||
const CGridDesc_M_N c_grid_desc_m_n,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CElementwiseOperation c_element_op,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \
|
||||
defined(__gfx12__)
|
||||
@@ -50,6 +48,12 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
{
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
|
||||
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n);
|
||||
|
||||
auto b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3 =
|
||||
GridwiseGemm::MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(b_grid_desc_k0_n_k1);
|
||||
|
||||
GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
@@ -67,8 +71,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
ignore = p_b_grid;
|
||||
ignore = p_c_grid;
|
||||
ignore = a_grid_desc_k0_m_k1;
|
||||
ignore = b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3;
|
||||
ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2;
|
||||
ignore = b_grid_desc_k0_n_k1;
|
||||
ignore = c_grid_desc_m_n;
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
ignore = c_element_op;
|
||||
@@ -375,20 +379,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
|
||||
return cblockid_to_m0_n0_block_cluster_adaptor;
|
||||
}
|
||||
|
||||
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
|
||||
decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
|
||||
using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
|
||||
using BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 =
|
||||
decltype(MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(BGridDesc_K0_N_K1{}));
|
||||
|
||||
template <bool HasMainK0BlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
|
||||
template <bool HasMainK0BlockLoop,
|
||||
typename BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3,
|
||||
typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
|
||||
typename Block2CTileMap = DefaultBlock2CTileMap>
|
||||
__device__ static void
|
||||
Run(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
void* __restrict__ p_shared,
|
||||
const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
|
||||
const BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
|
||||
const BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3& b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
|
||||
const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
|
||||
@@ -18,14 +18,13 @@
|
||||
#define CK_USE_OCP_FP8 0
|
||||
#endif
|
||||
|
||||
#if(defined(__gfx942__) || defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && \
|
||||
__HIP_DEVICE_COMPILE__
|
||||
#if(defined(__gfx942__) || defined(__gfx950__) || defined(__gfx12__)) && __HIP_DEVICE_COMPILE__
|
||||
#define CK_FP8_CVT_FAST_PATH 1
|
||||
#else
|
||||
#define CK_FP8_CVT_FAST_PATH 0
|
||||
#endif
|
||||
|
||||
#if(defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && __HIP_DEVICE_COMPILE__
|
||||
#if(defined(__gfx950__) || defined(__gfx12__)) && __HIP_DEVICE_COMPILE__
|
||||
#define CK_OCP_FP8_CVT_FAST_PATH 1
|
||||
#else
|
||||
#define CK_OCP_FP8_CVT_FAST_PATH 0
|
||||
@@ -390,7 +389,7 @@ struct bf8_ocp_t
|
||||
__host__ explicit operator float() const
|
||||
#endif
|
||||
{
|
||||
#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
|
||||
#if defined(__gfx950__) || defined(__gfx12__)
|
||||
return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
|
||||
#else
|
||||
return fp8_impl::cast_from_f8<float, wm, we, false>(
|
||||
@@ -404,7 +403,7 @@ struct bf8_ocp_t
|
||||
__host__ explicit operator _Float16() const
|
||||
#endif
|
||||
{
|
||||
#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
|
||||
#if defined(__gfx950__) || defined(__gfx12__)
|
||||
return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data));
|
||||
#else
|
||||
return fp8_impl::cast_from_f8<_Float16, wm, we, false>(
|
||||
|
||||
@@ -988,7 +988,7 @@ inline __host__ __device__ float2_t type_convert<float2_t, f8x2_ocp_t>(f8x2_ocp_
|
||||
#if CK_OCP_FP8_CVT_FAST_PATH
|
||||
// __builtin_amdgcn_cvt_pk_f32_fp8 can produce incorrect results due to a compiler issue.
|
||||
// TODO: Enable when SWDEV-532959 is fixed.
|
||||
#if defined(__gfx1200__) || defined(__gfx1201__)
|
||||
#if defined(__gfx12__)
|
||||
return float2_t{__builtin_amdgcn_cvt_f32_fp8(bit_cast<uint16_t>(x), 0),
|
||||
__builtin_amdgcn_cvt_f32_fp8(bit_cast<uint16_t>(x), 1)};
|
||||
#else
|
||||
@@ -1131,7 +1131,7 @@ inline __host__ __device__ float2_t type_convert<float2_t, bf8x2_ocp_t>(bf8x2_oc
|
||||
#if CK_OCP_FP8_CVT_FAST_PATH
|
||||
// __builtin_amdgcn_cvt_pk_f32_bf8 can produce incorrect results due to a compiler issue.
|
||||
// TODO: Enable when SWDEV-532959 is fixed.
|
||||
#if defined(__gfx1200__) || defined(__gfx1201__)
|
||||
#if defined(__gfx12__)
|
||||
return float2_t{__builtin_amdgcn_cvt_f32_bf8(bit_cast<uint16_t>(x), 0),
|
||||
__builtin_amdgcn_cvt_f32_bf8(bit_cast<uint16_t>(x), 1)};
|
||||
#else
|
||||
|
||||
Reference in New Issue
Block a user