[CK] Fix misc issues in CK examples (#2890)

* [CK] Fix misc CK issues

* revert fp8 change, it causes CI fail.

* resubmit fp8 change
This commit is contained in:
linqunAMD
2025-09-25 02:28:20 +08:00
committed by GitHub
parent 8fe3838c65
commit f076f207ce
10 changed files with 74 additions and 79 deletions

View File

@@ -62,29 +62,32 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
const Block2ETileMap block_2_etile_map,
index_t NRaw)
{
#if defined(__gfx9__)
__shared__ char p_shared[GridwiseGemmWelford::GetSharedMemoryNumberOfByte()];
#if defined(__gfx9__) || defined(__gfx12__)
if constexpr(GridwiseGemmWelford::template IsValidCompilationParameter<>())
{
__shared__ char p_shared[GridwiseGemmWelford::GetSharedMemoryNumberOfByte()];
GridwiseGemmWelford::template Run<HasMainKBlockLoop>(
p_a_grid,
p_b_grid,
p_ds_grid,
p_e_grid,
p_welford_mean_grid,
p_welford_var_grid,
p_welford_count_grid,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
mean_var_grid_desc_mblock_mperblock_nblock,
count_grid_desc_mblock_mperblock_nblock,
block_2_etile_map,
NRaw);
GridwiseGemmWelford::template Run<HasMainKBlockLoop>(
p_a_grid,
p_b_grid,
p_ds_grid,
p_e_grid,
p_welford_mean_grid,
p_welford_var_grid,
p_welford_count_grid,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
mean_var_grid_desc_mblock_mperblock_nblock,
count_grid_desc_mblock_mperblock_nblock,
block_2_etile_map,
NRaw);
}
#else
ignore = p_a_grid;
ignore = p_b_grid;

View File

@@ -321,12 +321,6 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
float ave_time = 0;
auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(arg.c_grid_desc_m_n_);
auto b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3 =
GridwiseGemm::MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(arg.b_grid_desc_k0_n_k1_);
if(has_main_k0_block_loop)
{
const auto kernel = kernel_gemm_xdlops_skip_b_lds_v1<
@@ -335,8 +329,7 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
CDataType,
remove_reference_t<DeviceGemmXdlSkipBLds::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdlSkipBLds::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<DeviceGemmXdlSkipBLds::CGridDesc_M_N>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
@@ -352,8 +345,8 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_,
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
@@ -367,8 +360,7 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
CDataType,
remove_reference_t<DeviceGemmXdlSkipBLds::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdlSkipBLds::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<DeviceGemmXdlSkipBLds::CGridDesc_M_N>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
@@ -384,8 +376,8 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_,
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,

View File

@@ -369,11 +369,10 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M
return transform_tensor_descriptor(
descriptor,
make_tuple(make_right_pad_transform(descriptor, MPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return transform_tensor_descriptor(descriptor,
make_tuple(make_right_pad_transform(MRaw, MPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
}
else
{
@@ -616,7 +615,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
using RDataType = remove_cvref_t<tuple_element_t<i.value, RsDataType>>;
// R pointer
p_rs_grid_(i) = static_cast<RDataType*>(p_rs[i]);
p_rs_grid_(i) = static_cast<RDataType*>(p_rs[i]);
compute_ptr_offset_of_batch_.BatchStrideRs_(i) = r_g_n_wos_strides[0];
});
}