mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK] Fix misc issues in CK examples (#2890)
* [CK] Fix misc CK issues * revert fp8 change, it causes CI fail. * resubmit fp8 change
This commit is contained in:
@@ -62,29 +62,32 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
const Block2ETileMap block_2_etile_map,
|
||||
index_t NRaw)
|
||||
{
|
||||
#if defined(__gfx9__)
|
||||
__shared__ char p_shared[GridwiseGemmWelford::GetSharedMemoryNumberOfByte()];
|
||||
#if defined(__gfx9__) || defined(__gfx12__)
|
||||
if constexpr(GridwiseGemmWelford::template IsValidCompilationParameter<>())
|
||||
{
|
||||
__shared__ char p_shared[GridwiseGemmWelford::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemmWelford::template Run<HasMainKBlockLoop>(
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_welford_mean_grid,
|
||||
p_welford_var_grid,
|
||||
p_welford_count_grid,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
mean_var_grid_desc_mblock_mperblock_nblock,
|
||||
count_grid_desc_mblock_mperblock_nblock,
|
||||
block_2_etile_map,
|
||||
NRaw);
|
||||
GridwiseGemmWelford::template Run<HasMainKBlockLoop>(
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_welford_mean_grid,
|
||||
p_welford_var_grid,
|
||||
p_welford_count_grid,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
mean_var_grid_desc_mblock_mperblock_nblock,
|
||||
count_grid_desc_mblock_mperblock_nblock,
|
||||
block_2_etile_map,
|
||||
NRaw);
|
||||
}
|
||||
#else
|
||||
ignore = p_a_grid;
|
||||
ignore = p_b_grid;
|
||||
|
||||
@@ -321,12 +321,6 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
|
||||
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(arg.c_grid_desc_m_n_);
|
||||
|
||||
auto b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3 =
|
||||
GridwiseGemm::MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(arg.b_grid_desc_k0_n_k1_);
|
||||
|
||||
if(has_main_k0_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_skip_b_lds_v1<
|
||||
@@ -335,8 +329,7 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
|
||||
CDataType,
|
||||
remove_reference_t<DeviceGemmXdlSkipBLds::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceGemmXdlSkipBLds::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<typename GridwiseGemm::BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3>,
|
||||
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
|
||||
remove_reference_t<DeviceGemmXdlSkipBLds::CGridDesc_M_N>,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
@@ -352,8 +345,8 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_,
|
||||
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
|
||||
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
@@ -367,8 +360,7 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
|
||||
CDataType,
|
||||
remove_reference_t<DeviceGemmXdlSkipBLds::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceGemmXdlSkipBLds::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<typename GridwiseGemm::BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3>,
|
||||
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
|
||||
remove_reference_t<DeviceGemmXdlSkipBLds::CGridDesc_M_N>,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
@@ -384,8 +376,8 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_,
|
||||
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
|
||||
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
|
||||
@@ -369,11 +369,10 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
|
||||
GemmSpec == GemmSpecialization::MNKPadding)
|
||||
{
|
||||
// pad M
|
||||
return transform_tensor_descriptor(
|
||||
descriptor,
|
||||
make_tuple(make_right_pad_transform(descriptor, MPad)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
return transform_tensor_descriptor(descriptor,
|
||||
make_tuple(make_right_pad_transform(MRaw, MPad)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -616,7 +615,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
|
||||
using RDataType = remove_cvref_t<tuple_element_t<i.value, RsDataType>>;
|
||||
|
||||
// R pointer
|
||||
p_rs_grid_(i) = static_cast<RDataType*>(p_rs[i]);
|
||||
p_rs_grid_(i) = static_cast<RDataType*>(p_rs[i]);
|
||||
compute_ptr_offset_of_batch_.BatchStrideRs_(i) = r_g_n_wos_strides[0];
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user