From 376e0992ef791791ed34708bc32c4fcd402e30c3 Mon Sep 17 00:00:00 2001 From: joye Date: Tue, 10 Jun 2025 08:34:27 +0800 Subject: [PATCH] fix compiling errors; now can pass compilation --- ...nv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp | 63 ++++++++++--------- 1 file changed, 34 insertions(+), 29 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index 60bfb963c9..9dc8dc4c3e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -74,28 +74,28 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle( - const ABDataType* __restrict__ p_a_grid, - const ABDataType* __restrict__ p_b_grid, - DsPointer p_ds_grid, - EDataType* __restrict__ p_e_grid, - const AElementwiseOp a_element_op, - const BElementwiseOp b_element_op, - const CDEElementwiseOp cde_element_op, - const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, - const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock, - const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock_, - const Block2ETileMap block_2_ctile_map, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - const ComputePtrOffsetOfN compute_ptr_offset_of_n, - const index_t KBatch) + kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle( + const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOp a_element_op, + const BElementwiseOp b_element_op, + const CDEElementwiseOp cde_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_, + const Block2ETileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const ComputePtrOffsetOfN compute_ptr_offset_of_n, + const index_t KBatch) { -#if (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) // offset base pointer for each work-group const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / KBatch); @@ -559,10 +559,10 @@ __global__ void kernel_grouped_conv_bwd_data_optimized_v2(const ABDataType* __re int kernel_x = (in_x + 1) * up_w - mid_x - 1; int kernel_y = (in_y + 1) * up_h - mid_y - 1; - using ABDTypeVec_t = typename vector_type::type; - using EDataTypeVec_t = typename vector_type::type; + using ABDTypeVec_t = vector_type; + using EDataTypeVec_t = vector_type; - using ETypeDstVec_t = typename vector_type::type; + using ETypeDstVec_t = vector_type; EDataTypeVec_t v{}; @@ -584,11 +584,10 @@ __global__ void kernel_grouped_conv_bwd_data_optimized_v2(const ABDataType* __re (kernel_x + x * up_w) * GroupPerBlock + group_out_id]; static_for<0, ElementPerInFP4, 1>{}([&](auto idx) { - float temp = v.AsType()[idx]; - inner_product(type_convert(shmem_x_vec.AsType()[idx]), - type_convert(shmem_k_vec.AsType()[idx]), - temp); - v.AsType()[idx] = temp; + auto x_val = shmem_x_vec.template AsType()[idx]; + auto k_val = shmem_k_vec.template AsType()[idx]; + + v.template AsType()(idx) += x_val * k_val; }); } @@ -601,8 +600,14 @@ __global__ void kernel_grouped_conv_bwd_data_optimized_v2(const ABDataType* __re (batch_start_id_per_blk + batch_id + batch_id_per_wave) * outgrad_batch_stride + out_y * outgrad_row_stride + out_x * outgrad_col_stride; + ETypeDstVec_t output_vec; + static_for<0, ElementPerOutFP4, 1>{}([&](auto idx) { + output_vec.template AsType()(idx) = + type_convert(v.template AsType()[idx]); + }); + reinterpret_cast(p_gradIn)[outgrad_offset / ElementPerOutFP4] = - type_convert(v); + output_vec; } } }