mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 10:59:55 +00:00
[CK] Fix misc issues in CK examples (#2890)
* [CK] Fix misc CK issues
* revert fp8 change, it causes CI fail.
* resubmit fp8 change
[ROCm/composable_kernel commit: f076f207ce]
This commit is contained in:
@@ -62,29 +62,32 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
const Block2ETileMap block_2_etile_map,
|
||||
index_t NRaw)
|
||||
{
|
||||
#if defined(__gfx9__)
|
||||
__shared__ char p_shared[GridwiseGemmWelford::GetSharedMemoryNumberOfByte()];
|
||||
#if defined(__gfx9__) || defined(__gfx12__)
|
||||
if constexpr(GridwiseGemmWelford::template IsValidCompilationParameter<>())
|
||||
{
|
||||
__shared__ char p_shared[GridwiseGemmWelford::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemmWelford::template Run<HasMainKBlockLoop>(
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_welford_mean_grid,
|
||||
p_welford_var_grid,
|
||||
p_welford_count_grid,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
mean_var_grid_desc_mblock_mperblock_nblock,
|
||||
count_grid_desc_mblock_mperblock_nblock,
|
||||
block_2_etile_map,
|
||||
NRaw);
|
||||
GridwiseGemmWelford::template Run<HasMainKBlockLoop>(
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_welford_mean_grid,
|
||||
p_welford_var_grid,
|
||||
p_welford_count_grid,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
mean_var_grid_desc_mblock_mperblock_nblock,
|
||||
count_grid_desc_mblock_mperblock_nblock,
|
||||
block_2_etile_map,
|
||||
NRaw);
|
||||
}
|
||||
#else
|
||||
ignore = p_a_grid;
|
||||
ignore = p_b_grid;
|
||||
|
||||
@@ -321,12 +321,6 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
|
||||
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(arg.c_grid_desc_m_n_);
|
||||
|
||||
auto b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3 =
|
||||
GridwiseGemm::MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(arg.b_grid_desc_k0_n_k1_);
|
||||
|
||||
if(has_main_k0_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_skip_b_lds_v1<
|
||||
@@ -335,8 +329,7 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
|
||||
CDataType,
|
||||
remove_reference_t<DeviceGemmXdlSkipBLds::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceGemmXdlSkipBLds::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<typename GridwiseGemm::BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3>,
|
||||
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
|
||||
remove_reference_t<DeviceGemmXdlSkipBLds::CGridDesc_M_N>,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
@@ -352,8 +345,8 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_,
|
||||
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
|
||||
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
@@ -367,8 +360,7 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
|
||||
CDataType,
|
||||
remove_reference_t<DeviceGemmXdlSkipBLds::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceGemmXdlSkipBLds::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<typename GridwiseGemm::BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3>,
|
||||
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
|
||||
remove_reference_t<DeviceGemmXdlSkipBLds::CGridDesc_M_N>,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
@@ -384,8 +376,8 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_,
|
||||
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
|
||||
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
|
||||
@@ -369,11 +369,10 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
|
||||
GemmSpec == GemmSpecialization::MNKPadding)
|
||||
{
|
||||
// pad M
|
||||
return transform_tensor_descriptor(
|
||||
descriptor,
|
||||
make_tuple(make_right_pad_transform(descriptor, MPad)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
return transform_tensor_descriptor(descriptor,
|
||||
make_tuple(make_right_pad_transform(MRaw, MPad)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -616,7 +615,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
|
||||
using RDataType = remove_cvref_t<tuple_element_t<i.value, RsDataType>>;
|
||||
|
||||
// R pointer
|
||||
p_rs_grid_(i) = static_cast<RDataType*>(p_rs[i]);
|
||||
p_rs_grid_(i) = static_cast<RDataType*>(p_rs[i]);
|
||||
compute_ptr_offset_of_batch_.BatchStrideRs_(i) = r_g_n_wos_strides[0];
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -21,8 +21,7 @@ template <typename GridwiseGemm,
|
||||
typename FloatC,
|
||||
typename AGridDesc_K0_M_K1,
|
||||
typename BGridDesc_K0_N_K1,
|
||||
typename BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3,
|
||||
typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
|
||||
typename CGridDesc_M_N,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
@@ -32,17 +31,16 @@ __global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_gemm_xdlops_skip_b_lds_v1(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
|
||||
const BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
|
||||
const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CElementwiseOperation c_element_op,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
kernel_gemm_xdlops_skip_b_lds_v1(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
|
||||
const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
|
||||
const CGridDesc_M_N c_grid_desc_m_n,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CElementwiseOperation c_element_op,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \
|
||||
defined(__gfx12__)
|
||||
@@ -50,6 +48,12 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
{
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
|
||||
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n);
|
||||
|
||||
auto b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3 =
|
||||
GridwiseGemm::MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(b_grid_desc_k0_n_k1);
|
||||
|
||||
GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
@@ -67,8 +71,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
ignore = p_b_grid;
|
||||
ignore = p_c_grid;
|
||||
ignore = a_grid_desc_k0_m_k1;
|
||||
ignore = b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3;
|
||||
ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2;
|
||||
ignore = b_grid_desc_k0_n_k1;
|
||||
ignore = c_grid_desc_m_n;
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
ignore = c_element_op;
|
||||
@@ -375,20 +379,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
|
||||
return cblockid_to_m0_n0_block_cluster_adaptor;
|
||||
}
|
||||
|
||||
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
|
||||
decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
|
||||
using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
|
||||
using BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 =
|
||||
decltype(MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(BGridDesc_K0_N_K1{}));
|
||||
|
||||
template <bool HasMainK0BlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
|
||||
template <bool HasMainK0BlockLoop,
|
||||
typename BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3,
|
||||
typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
|
||||
typename Block2CTileMap = DefaultBlock2CTileMap>
|
||||
__device__ static void
|
||||
Run(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
void* __restrict__ p_shared,
|
||||
const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
|
||||
const BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
|
||||
const BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3& b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
|
||||
const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
|
||||
@@ -18,14 +18,13 @@
|
||||
#define CK_USE_OCP_FP8 0
|
||||
#endif
|
||||
|
||||
#if(defined(__gfx942__) || defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && \
|
||||
__HIP_DEVICE_COMPILE__
|
||||
#if(defined(__gfx942__) || defined(__gfx950__) || defined(__gfx12__)) && __HIP_DEVICE_COMPILE__
|
||||
#define CK_FP8_CVT_FAST_PATH 1
|
||||
#else
|
||||
#define CK_FP8_CVT_FAST_PATH 0
|
||||
#endif
|
||||
|
||||
#if(defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && __HIP_DEVICE_COMPILE__
|
||||
#if(defined(__gfx950__) || defined(__gfx12__)) && __HIP_DEVICE_COMPILE__
|
||||
#define CK_OCP_FP8_CVT_FAST_PATH 1
|
||||
#else
|
||||
#define CK_OCP_FP8_CVT_FAST_PATH 0
|
||||
@@ -390,7 +389,7 @@ struct bf8_ocp_t
|
||||
__host__ explicit operator float() const
|
||||
#endif
|
||||
{
|
||||
#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
|
||||
#if defined(__gfx950__) || defined(__gfx12__)
|
||||
return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
|
||||
#else
|
||||
return fp8_impl::cast_from_f8<float, wm, we, false>(
|
||||
@@ -404,7 +403,7 @@ struct bf8_ocp_t
|
||||
__host__ explicit operator _Float16() const
|
||||
#endif
|
||||
{
|
||||
#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
|
||||
#if defined(__gfx950__) || defined(__gfx12__)
|
||||
return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data));
|
||||
#else
|
||||
return fp8_impl::cast_from_f8<_Float16, wm, we, false>(
|
||||
|
||||
@@ -988,7 +988,7 @@ inline __host__ __device__ float2_t type_convert<float2_t, f8x2_ocp_t>(f8x2_ocp_
|
||||
#if CK_OCP_FP8_CVT_FAST_PATH
|
||||
// __builtin_amdgcn_cvt_pk_f32_fp8 can produce incorrect results due to a compiler issue.
|
||||
// TODO: Enable when SWDEV-532959 is fixed.
|
||||
#if defined(__gfx1200__) || defined(__gfx1201__)
|
||||
#if defined(__gfx12__)
|
||||
return float2_t{__builtin_amdgcn_cvt_f32_fp8(bit_cast<uint16_t>(x), 0),
|
||||
__builtin_amdgcn_cvt_f32_fp8(bit_cast<uint16_t>(x), 1)};
|
||||
#else
|
||||
@@ -1131,7 +1131,7 @@ inline __host__ __device__ float2_t type_convert<float2_t, bf8x2_ocp_t>(bf8x2_oc
|
||||
#if CK_OCP_FP8_CVT_FAST_PATH
|
||||
// __builtin_amdgcn_cvt_pk_f32_bf8 can produce incorrect results due to a compiler issue.
|
||||
// TODO: Enable when SWDEV-532959 is fixed.
|
||||
#if defined(__gfx1200__) || defined(__gfx1201__)
|
||||
#if defined(__gfx12__)
|
||||
return float2_t{__builtin_amdgcn_cvt_f32_bf8(bit_cast<uint16_t>(x), 0),
|
||||
__builtin_amdgcn_cvt_f32_bf8(bit_cast<uint16_t>(x), 1)};
|
||||
#else
|
||||
|
||||
Reference in New Issue
Block a user