diff --git a/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp b/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp
index d149fd88f1..d5c42558c4 100644
--- a/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp
+++ b/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp
@@ -36,7 +36,7 @@ using BDataType = ck::half_t;
 using CDataType = ck::half_t;
 using AccDataType = float;
 #else
-    < F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 4, 4, 7, 1>;
+    < F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 16, 128, 4, 4, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 4, 4, 7, 1>;
 using ADataType = float;
 using BDataType = float;
 using CDataType = float;
@@ -185,7 +185,6 @@ int main(int argc, char* argv[])
     auto a_element_op = AElementOp{};
     auto b_element_op = BElementOp{};
     auto c_element_op = CElementOp{};
-
     // do GEMM
     auto gemm    = DeviceGemmInstance{};
     auto invoker = gemm.MakeInvoker();
@@ -209,8 +208,7 @@ int main(int argc, char* argv[])
         return 0;
     }
 
-    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
-
+    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
     std::size_t flop = std::size_t(2) * M * N * K;
     std::size_t num_btype =
         sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp b/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp
index 036f288d0a..7142521c55 100644
--- a/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp
+++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp
@@ -125,7 +125,7 @@ inline bool parse_cmd_args(int argc,
         const ck::index_t num_dim_spatial = std::stoi(argv[4]);
 
         problem_size = ck::utils::conv::parse_conv_param(
-            num_dim_spatial, threshold_to_catch_partial_args, argv);
+            num_dim_spatial, threshold_to_catch_partial_args + 1, argv);
     }
     else
     {
diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc b/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc
index c4e7068499..4b290d02a2 100644
--- a/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc
+++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc
@@ -23,7 +23,7 @@ using RsGlobalReduceOp =
 static constexpr auto ConvSpec =
     ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
 
-static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
 
 // clang-format off
 template
diff --git a/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp
index ce9f9b7032..ae5e3f36ad 100644
--- a/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp
+++ b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp
@@ -65,7 +65,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDLayern
 //######| | | | | Type| Type| Type| DataType| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ThreadClusterLengths| ScalarPerVector| ThreadClusterLengths| ThreadSliceSize|
 //######| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _M_N| _M_N| _M|
 //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-        < ALayout, BLayout, DsLayout, HLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EMeanVarDataType, GammaDataType, BetaDataType, HDataType, AElementOp, BElementOp, CDEElementOp, HElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<8, 32>, 8>;
+        < ALayout, BLayout, DsLayout, HLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EMeanVarDataType, GammaDataType, BetaDataType, HDataType, AElementOp, BElementOp, CDEElementOp, HElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 4, S<8, 32>, 4>;
 // clang-format on
 
 auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
@@ -154,8 +154,8 @@ void host_gemm_layernorm(Tensor& h_m_n,
 
 int main()
 {
-    // temp disable on gfx11 & gfx12
-    if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
+    // temp disable on gfx11
+    if(ck::is_gfx11_supported())
     {
         return 0;
     }
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
index 0abc30d7a2..52ecbeea6b 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
@@ -62,29 +62,32 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
         const Block2ETileMap block_2_etile_map,
         index_t NRaw)
 {
-#if defined(__gfx9__)
-    __shared__ char p_shared[GridwiseGemmWelford::GetSharedMemoryNumberOfByte()];
+#if defined(__gfx9__) || defined(__gfx12__)
+    if constexpr(GridwiseGemmWelford::template IsValidCompilationParameter<>())
+    {
+        __shared__ char p_shared[GridwiseGemmWelford::GetSharedMemoryNumberOfByte()];
 
-    GridwiseGemmWelford::template Run(
-        p_a_grid,
-        p_b_grid,
-        p_ds_grid,
-        p_e_grid,
-        p_welford_mean_grid,
-        p_welford_var_grid,
-        p_welford_count_grid,
-        p_shared,
-        a_element_op,
-        b_element_op,
-        cde_element_op,
-        a_grid_desc_ak0_m_ak1,
-        b_grid_desc_bk0_n_bk1,
-        ds_grid_desc_mblock_mperblock_nblock_nperblock,
-        e_grid_desc_mblock_mperblock_nblock_nperblock,
-        mean_var_grid_desc_mblock_mperblock_nblock,
-        count_grid_desc_mblock_mperblock_nblock,
-        block_2_etile_map,
-        NRaw);
+        GridwiseGemmWelford::template Run(
+            p_a_grid,
+            p_b_grid,
+            p_ds_grid,
+            p_e_grid,
+            p_welford_mean_grid,
+            p_welford_var_grid,
+            p_welford_count_grid,
+            p_shared,
+            a_element_op,
+            b_element_op,
+            cde_element_op,
+            a_grid_desc_ak0_m_ak1,
+            b_grid_desc_bk0_n_bk1,
+            ds_grid_desc_mblock_mperblock_nblock_nperblock,
+            e_grid_desc_mblock_mperblock_nblock_nperblock,
+            mean_var_grid_desc_mblock_mperblock_nblock,
+            count_grid_desc_mblock_mperblock_nblock,
+            block_2_etile_map,
+            NRaw);
+    }
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp
index bc192b7651..4abd14b080 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp
@@ -321,12 +321,6 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm
             remove_reference_t,
-            remove_reference_t,
-            remove_reference_t,
+            remove_reference_t,
             AElementwiseOperation,
             BElementwiseOperation,
             CElementwiseOperation,
@@ -352,8 +345,8 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm
             remove_reference_t,
-            remove_reference_t,
-            remove_reference_t,
+            remove_reference_t,
             AElementwiseOperation,
             BElementwiseOperation,
             CElementwiseOperation,
@@ -384,8 +376,8 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm
-                                               make_tuple(Sequence<0>{}),
-                                               make_tuple(Sequence<0>{}));
+            return transform_tensor_descriptor(descriptor,
+                                               make_tuple(make_right_pad_transform(MRaw, MPad)),
+                                               make_tuple(Sequence<0>{}),
+                                               make_tuple(Sequence<0>{}));
         }
         else
         {
@@ -616,7 +615,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
                 using RDataType = remove_cvref_t>;
 
                 // R pointer
-                p_rs_grid_(i) = static_cast(p_rs[i]);
+                p_rs_grid_(i) = static_cast(p_rs[i]);
+                compute_ptr_offset_of_batch_.BatchStrideRs_(i) = r_g_n_wos_strides[0];
             });
         }
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp
index 9e524c5a23..cf3040d1ae 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp
@@ -21,8 +21,7 @@ template
        (p_a_grid, p_b_grid, p_c_grid,
@@ -67,8 +71,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
     ignore = p_b_grid;
     ignore = p_c_grid;
     ignore = a_grid_desc_k0_m_k1;
-    ignore = b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3;
-    ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2;
+    ignore = b_grid_desc_k0_n_k1;
+    ignore = c_grid_desc_m_n;
     ignore = a_element_op;
     ignore = b_element_op;
     ignore = c_element_op;
@@ -375,20 +379,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
         return cblockid_to_m0_n0_block_cluster_adaptor;
     }
 
-    using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
-        decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
     using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
-    using BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 =
-        decltype(MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(BGridDesc_K0_N_K1{}));
 
-    template
+    template
     __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                                const FloatAB* __restrict__ p_b_grid,
                                FloatC* __restrict__ p_c_grid,
                                void* __restrict__ p_shared,
                                const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
-                               const BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
+                               const BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3& b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                const AElementwiseOperation& a_element_op,
                                const BElementwiseOperation& b_element_op,
diff --git a/include/ck/utility/amd_ck_fp8.hpp b/include/ck/utility/amd_ck_fp8.hpp
index 0b73f76155..2c00f4f42f 100644
--- a/include/ck/utility/amd_ck_fp8.hpp
+++ b/include/ck/utility/amd_ck_fp8.hpp
@@ -18,14 +18,13 @@
 #define CK_USE_OCP_FP8 0
 #endif
 
-#if(defined(__gfx942__) || defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && \
-    __HIP_DEVICE_COMPILE__
+#if(defined(__gfx942__) || defined(__gfx950__) || defined(__gfx12__)) && __HIP_DEVICE_COMPILE__
 #define CK_FP8_CVT_FAST_PATH 1
 #else
 #define CK_FP8_CVT_FAST_PATH 0
 #endif
 
-#if(defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && __HIP_DEVICE_COMPILE__
+#if(defined(__gfx950__) || defined(__gfx12__)) && __HIP_DEVICE_COMPILE__
 #define CK_OCP_FP8_CVT_FAST_PATH 1
 #else
 #define CK_OCP_FP8_CVT_FAST_PATH 0
@@ -390,7 +389,7 @@ struct bf8_ocp_t
     __host__ explicit operator float() const
 #endif
     {
-#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
+#if defined(__gfx950__) || defined(__gfx12__)
         return fp8_impl::cast_to_f32_from_f8(this->data);
 #else
         return fp8_impl::cast_from_f8(
@@ -404,7 +403,7 @@ struct bf8_ocp_t
     __host__ explicit operator _Float16() const
 #endif
     {
-#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
+#if defined(__gfx950__) || defined(__gfx12__)
         return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8(this->data));
 #else
         return fp8_impl::cast_from_f8<_Float16, wm, we, false>(
diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp
index 66d760c2b3..701b2686c7 100644
--- a/include/ck/utility/type_convert.hpp
+++ b/include/ck/utility/type_convert.hpp
@@ -988,7 +988,7 @@ inline __host__ __device__ float2_t type_convert(f8x2_ocp_
 #if CK_OCP_FP8_CVT_FAST_PATH
 // __builtin_amdgcn_cvt_pk_f32_fp8 can produce incorrect results due to a compiler issue.
 // TODO: Enable when SWDEV-532959 is fixed.
-#if defined(__gfx1200__) || defined(__gfx1201__)
+#if defined(__gfx12__)
     return float2_t{__builtin_amdgcn_cvt_f32_fp8(bit_cast(x), 0),
                     __builtin_amdgcn_cvt_f32_fp8(bit_cast(x), 1)};
 #else
@@ -1131,7 +1131,7 @@ inline __host__ __device__ float2_t type_convert(bf8x2_oc
 #if CK_OCP_FP8_CVT_FAST_PATH
 // __builtin_amdgcn_cvt_pk_f32_bf8 can produce incorrect results due to a compiler issue.
 // TODO: Enable when SWDEV-532959 is fixed.
-#if defined(__gfx1200__) || defined(__gfx1201__)
+#if defined(__gfx12__)
     return float2_t{__builtin_amdgcn_cvt_f32_bf8(bit_cast(x), 0),
                     __builtin_amdgcn_cvt_f32_bf8(bit_cast(x), 1)};
 #else