[CK] Fix misc issues in CK examples (#2890)

* [CK] Fix misc CK issues

* revert fp8 change, it causes CI fail.

* resubmit fp8 change

[ROCm/composable_kernel commit: f076f207ce]
This commit is contained in:
linqunAMD
2025-09-25 02:28:20 +08:00
committed by GitHub
parent 7e537fd72f
commit 0c45597a4e
10 changed files with 74 additions and 79 deletions

View File

@@ -36,7 +36,7 @@ using BDataType = ck::half_t;
using CDataType = ck::half_t;
using AccDataType = float;
#else
< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 4, 4, 7, 1>;
< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 16, 128, 4, 4, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 4, 4, 7, 1>;
using ADataType = float;
using BDataType = float;
using CDataType = float;
@@ -185,7 +185,6 @@ int main(int argc, char* argv[])
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
@@ -209,8 +208,7 @@ int main(int argc, char* argv[])
return 0;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;

View File

@@ -125,7 +125,7 @@ inline bool parse_cmd_args(int argc,
const ck::index_t num_dim_spatial = std::stoi(argv[4]);
problem_size = ck::utils::conv::parse_conv_param(
num_dim_spatial, threshold_to_catch_partial_args, argv);
num_dim_spatial, threshold_to_catch_partial_args + 1, argv);
}
else
{

View File

@@ -23,7 +23,7 @@ using RsGlobalReduceOp =
static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// clang-format off
template <ck::index_t NDimSpatial>

View File

@@ -65,7 +65,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDLayern
//######| | | | | Type| Type| Type| DataType| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ThreadClusterLengths| ScalarPerVector| ThreadClusterLengths| ThreadSliceSize|
//######| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _M_N| _M_N| _M|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, HLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EMeanVarDataType, GammaDataType, BetaDataType, HDataType, AElementOp, BElementOp, CDEElementOp, HElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<8, 32>, 8>;
< ALayout, BLayout, DsLayout, HLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EMeanVarDataType, GammaDataType, BetaDataType, HDataType, AElementOp, BElementOp, CDEElementOp, HElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 4, S<8, 32>, 4>;
// clang-format on
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
@@ -154,8 +154,8 @@ void host_gemm_layernorm(Tensor<HDataType>& h_m_n,
int main()
{
// temp disable on gfx11 & gfx12
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
// temp disable on gfx11
if(ck::is_gfx11_supported())
{
return 0;
}

View File

@@ -62,29 +62,32 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
const Block2ETileMap block_2_etile_map,
index_t NRaw)
{
#if defined(__gfx9__)
__shared__ char p_shared[GridwiseGemmWelford::GetSharedMemoryNumberOfByte()];
#if defined(__gfx9__) || defined(__gfx12__)
if constexpr(GridwiseGemmWelford::template IsValidCompilationParameter<>())
{
__shared__ char p_shared[GridwiseGemmWelford::GetSharedMemoryNumberOfByte()];
GridwiseGemmWelford::template Run<HasMainKBlockLoop>(
p_a_grid,
p_b_grid,
p_ds_grid,
p_e_grid,
p_welford_mean_grid,
p_welford_var_grid,
p_welford_count_grid,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
mean_var_grid_desc_mblock_mperblock_nblock,
count_grid_desc_mblock_mperblock_nblock,
block_2_etile_map,
NRaw);
GridwiseGemmWelford::template Run<HasMainKBlockLoop>(
p_a_grid,
p_b_grid,
p_ds_grid,
p_e_grid,
p_welford_mean_grid,
p_welford_var_grid,
p_welford_count_grid,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
mean_var_grid_desc_mblock_mperblock_nblock,
count_grid_desc_mblock_mperblock_nblock,
block_2_etile_map,
NRaw);
}
#else
ignore = p_a_grid;
ignore = p_b_grid;

View File

@@ -321,12 +321,6 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
float ave_time = 0;
auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(arg.c_grid_desc_m_n_);
auto b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3 =
GridwiseGemm::MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(arg.b_grid_desc_k0_n_k1_);
if(has_main_k0_block_loop)
{
const auto kernel = kernel_gemm_xdlops_skip_b_lds_v1<
@@ -335,8 +329,7 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
CDataType,
remove_reference_t<DeviceGemmXdlSkipBLds::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdlSkipBLds::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<DeviceGemmXdlSkipBLds::CGridDesc_M_N>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
@@ -352,8 +345,8 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_,
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
@@ -367,8 +360,7 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
CDataType,
remove_reference_t<DeviceGemmXdlSkipBLds::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdlSkipBLds::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<DeviceGemmXdlSkipBLds::CGridDesc_M_N>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
@@ -384,8 +376,8 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_,
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,

View File

@@ -369,11 +369,10 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M
return transform_tensor_descriptor(
descriptor,
make_tuple(make_right_pad_transform(descriptor, MPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return transform_tensor_descriptor(descriptor,
make_tuple(make_right_pad_transform(MRaw, MPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
}
else
{
@@ -616,7 +615,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
using RDataType = remove_cvref_t<tuple_element_t<i.value, RsDataType>>;
// R pointer
p_rs_grid_(i) = static_cast<RDataType*>(p_rs[i]);
p_rs_grid_(i) = static_cast<RDataType*>(p_rs[i]);
compute_ptr_offset_of_batch_.BatchStrideRs_(i) = r_g_n_wos_strides[0];
});
}

View File

@@ -21,8 +21,7 @@ template <typename GridwiseGemm,
typename FloatC,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3,
typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
typename CGridDesc_M_N,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
@@ -32,17 +31,16 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_xdlops_skip_b_lds_v1(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
const BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const Block2CTileMap block_2_ctile_map)
kernel_gemm_xdlops_skip_b_lds_v1(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
const CGridDesc_M_N c_grid_desc_m_n,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const Block2CTileMap block_2_ctile_map)
{
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \
defined(__gfx12__)
@@ -50,6 +48,12 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
{
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n);
auto b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3 =
GridwiseGemm::MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(b_grid_desc_k0_n_k1);
GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
@@ -67,8 +71,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
ignore = p_b_grid;
ignore = p_c_grid;
ignore = a_grid_desc_k0_m_k1;
ignore = b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3;
ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2;
ignore = b_grid_desc_k0_n_k1;
ignore = c_grid_desc_m_n;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
@@ -375,20 +379,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
return cblockid_to_m0_n0_block_cluster_adaptor;
}
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
using BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 =
decltype(MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(BGridDesc_K0_N_K1{}));
template <bool HasMainK0BlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
template <bool HasMainK0BlockLoop,
typename BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3,
typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
typename Block2CTileMap = DefaultBlock2CTileMap>
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
void* __restrict__ p_shared,
const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
const BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3& b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,

View File

@@ -18,14 +18,13 @@
#define CK_USE_OCP_FP8 0
#endif
#if(defined(__gfx942__) || defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && \
__HIP_DEVICE_COMPILE__
#if(defined(__gfx942__) || defined(__gfx950__) || defined(__gfx12__)) && __HIP_DEVICE_COMPILE__
#define CK_FP8_CVT_FAST_PATH 1
#else
#define CK_FP8_CVT_FAST_PATH 0
#endif
#if(defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && __HIP_DEVICE_COMPILE__
#if(defined(__gfx950__) || defined(__gfx12__)) && __HIP_DEVICE_COMPILE__
#define CK_OCP_FP8_CVT_FAST_PATH 1
#else
#define CK_OCP_FP8_CVT_FAST_PATH 0
@@ -390,7 +389,7 @@ struct bf8_ocp_t
__host__ explicit operator float() const
#endif
{
#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
#if defined(__gfx950__) || defined(__gfx12__)
return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
#else
return fp8_impl::cast_from_f8<float, wm, we, false>(
@@ -404,7 +403,7 @@ struct bf8_ocp_t
__host__ explicit operator _Float16() const
#endif
{
#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
#if defined(__gfx950__) || defined(__gfx12__)
return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data));
#else
return fp8_impl::cast_from_f8<_Float16, wm, we, false>(

View File

@@ -988,7 +988,7 @@ inline __host__ __device__ float2_t type_convert<float2_t, f8x2_ocp_t>(f8x2_ocp_
#if CK_OCP_FP8_CVT_FAST_PATH
// __builtin_amdgcn_cvt_pk_f32_fp8 can produce incorrect results due to a compiler issue.
// TODO: Enable when SWDEV-532959 is fixed.
#if defined(__gfx1200__) || defined(__gfx1201__)
#if defined(__gfx12__)
return float2_t{__builtin_amdgcn_cvt_f32_fp8(bit_cast<uint16_t>(x), 0),
__builtin_amdgcn_cvt_f32_fp8(bit_cast<uint16_t>(x), 1)};
#else
@@ -1131,7 +1131,7 @@ inline __host__ __device__ float2_t type_convert<float2_t, bf8x2_ocp_t>(bf8x2_oc
#if CK_OCP_FP8_CVT_FAST_PATH
// __builtin_amdgcn_cvt_pk_f32_bf8 can produce incorrect results due to a compiler issue.
// TODO: Enable when SWDEV-532959 is fixed.
#if defined(__gfx1200__) || defined(__gfx1201__)
#if defined(__gfx12__)
return float2_t{__builtin_amdgcn_cvt_f32_bf8(bit_cast<uint16_t>(x), 0),
__builtin_amdgcn_cvt_f32_bf8(bit_cast<uint16_t>(x), 1)};
#else