fix compiling errors; now can pass compilation

This commit is contained in:
joye
2025-06-10 08:34:27 +08:00
parent 4e469f5572
commit 376e0992ef

View File

@@ -74,28 +74,28 @@ template <typename GridwiseGemm,
InMemoryDataOperationEnum OutElementOp>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle(
const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
const AElementwiseOp a_element_op,
const BElementwiseOp b_element_op,
const CDEElementwiseOp cde_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_,
const Block2ETileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const ComputePtrOffsetOfN compute_ptr_offset_of_n,
const index_t KBatch)
kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle(
const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
const AElementwiseOp a_element_op,
const BElementwiseOp b_element_op,
const CDEElementwiseOp cde_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_,
const Block2ETileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const ComputePtrOffsetOfN compute_ptr_offset_of_n,
const index_t KBatch)
{
#if (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
// offset base pointer for each work-group
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / KBatch);
@@ -559,10 +559,10 @@ __global__ void kernel_grouped_conv_bwd_data_optimized_v2(const ABDataType* __re
int kernel_x = (in_x + 1) * up_w - mid_x - 1;
int kernel_y = (in_y + 1) * up_h - mid_y - 1;
using ABDTypeVec_t = typename vector_type<ABDataType, ElementPerInFP4>::type;
using EDataTypeVec_t = typename vector_type<float, ElementPerInFP4>::type;
using ABDTypeVec_t = vector_type<ABDataType, ElementPerInFP4>;
using EDataTypeVec_t = vector_type<float, ElementPerInFP4>;
using ETypeDstVec_t = typename vector_type<EDataType, ElementPerOutFP4>::type;
using ETypeDstVec_t = vector_type<EDataType, ElementPerOutFP4>;
EDataTypeVec_t v{};
@@ -584,11 +584,10 @@ __global__ void kernel_grouped_conv_bwd_data_optimized_v2(const ABDataType* __re
(kernel_x + x * up_w) * GroupPerBlock + group_out_id];
static_for<0, ElementPerInFP4, 1>{}([&](auto idx) {
float temp = v.AsType<float>()[idx];
inner_product(type_convert<float>(shmem_x_vec.AsType<ABDataType>()[idx]),
type_convert<float>(shmem_k_vec.AsType<ABDataType>()[idx]),
temp);
v.AsType<float>()[idx] = temp;
auto x_val = shmem_x_vec.template AsType<ABDataType>()[idx];
auto k_val = shmem_k_vec.template AsType<ABDataType>()[idx];
v.template AsType<float>()(idx) += x_val * k_val;
});
}
@@ -601,8 +600,14 @@ __global__ void kernel_grouped_conv_bwd_data_optimized_v2(const ABDataType* __re
(batch_start_id_per_blk + batch_id + batch_id_per_wave) * outgrad_batch_stride +
out_y * outgrad_row_stride + out_x * outgrad_col_stride;
ETypeDstVec_t output_vec;
static_for<0, ElementPerOutFP4, 1>{}([&](auto idx) {
output_vec.template AsType<EDataType>()(idx) =
type_convert<EDataType>(v.template AsType<float>()[idx]);
});
reinterpret_cast<ETypeDstVec_t*>(p_gradIn)[outgrad_offset / ElementPerOutFP4] =
type_convert<ETypeDstVec_t>(v);
output_vec;
}
}
}