[CK] Fix misc issues in CK examples (#2890)

* [CK] Fix misc CK issues

* revert fp8 change, it causes CI failures.

* resubmit fp8 change

[ROCm/composable_kernel commit: f076f207ce]
This commit is contained in:
linqunAMD
2025-09-25 02:28:20 +08:00
committed by GitHub
parent 7e537fd72f
commit 0c45597a4e
10 changed files with 74 additions and 79 deletions

View File

@@ -62,29 +62,32 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
const Block2ETileMap block_2_etile_map,
index_t NRaw)
{
#if defined(__gfx9__)
__shared__ char p_shared[GridwiseGemmWelford::GetSharedMemoryNumberOfByte()];
#if defined(__gfx9__) || defined(__gfx12__)
if constexpr(GridwiseGemmWelford::template IsValidCompilationParameter<>())
{
__shared__ char p_shared[GridwiseGemmWelford::GetSharedMemoryNumberOfByte()];
GridwiseGemmWelford::template Run<HasMainKBlockLoop>(
p_a_grid,
p_b_grid,
p_ds_grid,
p_e_grid,
p_welford_mean_grid,
p_welford_var_grid,
p_welford_count_grid,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
mean_var_grid_desc_mblock_mperblock_nblock,
count_grid_desc_mblock_mperblock_nblock,
block_2_etile_map,
NRaw);
GridwiseGemmWelford::template Run<HasMainKBlockLoop>(
p_a_grid,
p_b_grid,
p_ds_grid,
p_e_grid,
p_welford_mean_grid,
p_welford_var_grid,
p_welford_count_grid,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
mean_var_grid_desc_mblock_mperblock_nblock,
count_grid_desc_mblock_mperblock_nblock,
block_2_etile_map,
NRaw);
}
#else
ignore = p_a_grid;
ignore = p_b_grid;

View File

@@ -321,12 +321,6 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
float ave_time = 0;
auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(arg.c_grid_desc_m_n_);
auto b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3 =
GridwiseGemm::MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(arg.b_grid_desc_k0_n_k1_);
if(has_main_k0_block_loop)
{
const auto kernel = kernel_gemm_xdlops_skip_b_lds_v1<
@@ -335,8 +329,7 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
CDataType,
remove_reference_t<DeviceGemmXdlSkipBLds::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdlSkipBLds::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<DeviceGemmXdlSkipBLds::CGridDesc_M_N>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
@@ -352,8 +345,8 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_,
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
@@ -367,8 +360,7 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
CDataType,
remove_reference_t<DeviceGemmXdlSkipBLds::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdlSkipBLds::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<DeviceGemmXdlSkipBLds::CGridDesc_M_N>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
@@ -384,8 +376,8 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_,
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,

View File

@@ -369,11 +369,10 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M
return transform_tensor_descriptor(
descriptor,
make_tuple(make_right_pad_transform(descriptor, MPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return transform_tensor_descriptor(descriptor,
make_tuple(make_right_pad_transform(MRaw, MPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
}
else
{
@@ -616,7 +615,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
using RDataType = remove_cvref_t<tuple_element_t<i.value, RsDataType>>;
// R pointer
p_rs_grid_(i) = static_cast<RDataType*>(p_rs[i]);
p_rs_grid_(i) = static_cast<RDataType*>(p_rs[i]);
compute_ptr_offset_of_batch_.BatchStrideRs_(i) = r_g_n_wos_strides[0];
});
}

View File

@@ -21,8 +21,7 @@ template <typename GridwiseGemm,
typename FloatC,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3,
typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
typename CGridDesc_M_N,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
@@ -32,17 +31,16 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_xdlops_skip_b_lds_v1(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
const BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const Block2CTileMap block_2_ctile_map)
kernel_gemm_xdlops_skip_b_lds_v1(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
const CGridDesc_M_N c_grid_desc_m_n,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const Block2CTileMap block_2_ctile_map)
{
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \
defined(__gfx12__)
@@ -50,6 +48,12 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
{
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n);
auto b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3 =
GridwiseGemm::MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(b_grid_desc_k0_n_k1);
GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
@@ -67,8 +71,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
ignore = p_b_grid;
ignore = p_c_grid;
ignore = a_grid_desc_k0_m_k1;
ignore = b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3;
ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2;
ignore = b_grid_desc_k0_n_k1;
ignore = c_grid_desc_m_n;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
@@ -375,20 +379,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
return cblockid_to_m0_n0_block_cluster_adaptor;
}
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
using BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 =
decltype(MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(BGridDesc_K0_N_K1{}));
template <bool HasMainK0BlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
template <bool HasMainK0BlockLoop,
typename BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3,
typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
typename Block2CTileMap = DefaultBlock2CTileMap>
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
void* __restrict__ p_shared,
const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
const BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3& b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,

View File

@@ -18,14 +18,13 @@
#define CK_USE_OCP_FP8 0
#endif
#if(defined(__gfx942__) || defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && \
__HIP_DEVICE_COMPILE__
#if(defined(__gfx942__) || defined(__gfx950__) || defined(__gfx12__)) && __HIP_DEVICE_COMPILE__
#define CK_FP8_CVT_FAST_PATH 1
#else
#define CK_FP8_CVT_FAST_PATH 0
#endif
#if(defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && __HIP_DEVICE_COMPILE__
#if(defined(__gfx950__) || defined(__gfx12__)) && __HIP_DEVICE_COMPILE__
#define CK_OCP_FP8_CVT_FAST_PATH 1
#else
#define CK_OCP_FP8_CVT_FAST_PATH 0
@@ -390,7 +389,7 @@ struct bf8_ocp_t
__host__ explicit operator float() const
#endif
{
#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
#if defined(__gfx950__) || defined(__gfx12__)
return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
#else
return fp8_impl::cast_from_f8<float, wm, we, false>(
@@ -404,7 +403,7 @@ struct bf8_ocp_t
__host__ explicit operator _Float16() const
#endif
{
#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
#if defined(__gfx950__) || defined(__gfx12__)
return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data));
#else
return fp8_impl::cast_from_f8<_Float16, wm, we, false>(

View File

@@ -988,7 +988,7 @@ inline __host__ __device__ float2_t type_convert<float2_t, f8x2_ocp_t>(f8x2_ocp_
#if CK_OCP_FP8_CVT_FAST_PATH
// __builtin_amdgcn_cvt_pk_f32_fp8 can produce incorrect results due to a compiler issue.
// TODO: Enable when SWDEV-532959 is fixed.
#if defined(__gfx1200__) || defined(__gfx1201__)
#if defined(__gfx12__)
return float2_t{__builtin_amdgcn_cvt_f32_fp8(bit_cast<uint16_t>(x), 0),
__builtin_amdgcn_cvt_f32_fp8(bit_cast<uint16_t>(x), 1)};
#else
@@ -1131,7 +1131,7 @@ inline __host__ __device__ float2_t type_convert<float2_t, bf8x2_ocp_t>(bf8x2_oc
#if CK_OCP_FP8_CVT_FAST_PATH
// __builtin_amdgcn_cvt_pk_f32_bf8 can produce incorrect results due to a compiler issue.
// TODO: Enable when SWDEV-532959 is fixed.
#if defined(__gfx1200__) || defined(__gfx1201__)
#if defined(__gfx12__)
return float2_t{__builtin_amdgcn_cvt_f32_bf8(bit_cast<uint16_t>(x), 0),
__builtin_amdgcn_cvt_f32_bf8(bit_cast<uint16_t>(x), 1)};
#else