mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK] Fix misc issues in CK examples (#2890)
* [CK] Fix misc CK issues * revert fp8 change, it causes CI fail. * resubmit fp8 change
This commit is contained in:
@@ -21,8 +21,7 @@ template <typename GridwiseGemm,
|
||||
typename FloatC,
|
||||
typename AGridDesc_K0_M_K1,
|
||||
typename BGridDesc_K0_N_K1,
|
||||
typename BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3,
|
||||
typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
|
||||
typename CGridDesc_M_N,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
@@ -32,17 +31,16 @@ __global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_gemm_xdlops_skip_b_lds_v1(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
|
||||
const BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
|
||||
const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CElementwiseOperation c_element_op,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
kernel_gemm_xdlops_skip_b_lds_v1(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
|
||||
const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
|
||||
const CGridDesc_M_N c_grid_desc_m_n,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CElementwiseOperation c_element_op,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \
|
||||
defined(__gfx12__)
|
||||
@@ -50,6 +48,12 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
{
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
|
||||
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n);
|
||||
|
||||
auto b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3 =
|
||||
GridwiseGemm::MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(b_grid_desc_k0_n_k1);
|
||||
|
||||
GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
@@ -67,8 +71,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
ignore = p_b_grid;
|
||||
ignore = p_c_grid;
|
||||
ignore = a_grid_desc_k0_m_k1;
|
||||
ignore = b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3;
|
||||
ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2;
|
||||
ignore = b_grid_desc_k0_n_k1;
|
||||
ignore = c_grid_desc_m_n;
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
ignore = c_element_op;
|
||||
@@ -375,20 +379,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
|
||||
return cblockid_to_m0_n0_block_cluster_adaptor;
|
||||
}
|
||||
|
||||
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
|
||||
decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
|
||||
using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
|
||||
using BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 =
|
||||
decltype(MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(BGridDesc_K0_N_K1{}));
|
||||
|
||||
template <bool HasMainK0BlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
|
||||
template <bool HasMainK0BlockLoop,
|
||||
typename BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3,
|
||||
typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
|
||||
typename Block2CTileMap = DefaultBlock2CTileMap>
|
||||
__device__ static void
|
||||
Run(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
void* __restrict__ p_shared,
|
||||
const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
|
||||
const BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
|
||||
const BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3& b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
|
||||
const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
|
||||
Reference in New Issue
Block a user