mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
fix compiling errors; now can pass compilation
This commit is contained in:
@@ -74,28 +74,28 @@ template <typename GridwiseGemm,
|
||||
InMemoryDataOperationEnum OutElementOp>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle(
|
||||
const ABDataType* __restrict__ p_a_grid,
|
||||
const ABDataType* __restrict__ p_b_grid,
|
||||
DsPointer p_ds_grid,
|
||||
EDataType* __restrict__ p_e_grid,
|
||||
const AElementwiseOp a_element_op,
|
||||
const BElementwiseOp b_element_op,
|
||||
const CDEElementwiseOp cde_element_op,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
const Block2ETileMap block_2_ctile_map,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const ComputePtrOffsetOfN compute_ptr_offset_of_n,
|
||||
const index_t KBatch)
|
||||
kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle(
|
||||
const ABDataType* __restrict__ p_a_grid,
|
||||
const ABDataType* __restrict__ p_b_grid,
|
||||
DsPointer p_ds_grid,
|
||||
EDataType* __restrict__ p_e_grid,
|
||||
const AElementwiseOp a_element_op,
|
||||
const BElementwiseOp b_element_op,
|
||||
const CDEElementwiseOp cde_element_op,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
const Block2ETileMap block_2_ctile_map,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const ComputePtrOffsetOfN compute_ptr_offset_of_n,
|
||||
const index_t KBatch)
|
||||
{
|
||||
#if (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
// offset base pointer for each work-group
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
|
||||
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / KBatch);
|
||||
@@ -559,10 +559,10 @@ __global__ void kernel_grouped_conv_bwd_data_optimized_v2(const ABDataType* __re
|
||||
int kernel_x = (in_x + 1) * up_w - mid_x - 1;
|
||||
int kernel_y = (in_y + 1) * up_h - mid_y - 1;
|
||||
|
||||
using ABDTypeVec_t = typename vector_type<ABDataType, ElementPerInFP4>::type;
|
||||
using EDataTypeVec_t = typename vector_type<float, ElementPerInFP4>::type;
|
||||
using ABDTypeVec_t = vector_type<ABDataType, ElementPerInFP4>;
|
||||
using EDataTypeVec_t = vector_type<float, ElementPerInFP4>;
|
||||
|
||||
using ETypeDstVec_t = typename vector_type<EDataType, ElementPerOutFP4>::type;
|
||||
using ETypeDstVec_t = vector_type<EDataType, ElementPerOutFP4>;
|
||||
|
||||
EDataTypeVec_t v{};
|
||||
|
||||
@@ -584,11 +584,10 @@ __global__ void kernel_grouped_conv_bwd_data_optimized_v2(const ABDataType* __re
|
||||
(kernel_x + x * up_w) * GroupPerBlock + group_out_id];
|
||||
|
||||
static_for<0, ElementPerInFP4, 1>{}([&](auto idx) {
|
||||
float temp = v.AsType<float>()[idx];
|
||||
inner_product(type_convert<float>(shmem_x_vec.AsType<ABDataType>()[idx]),
|
||||
type_convert<float>(shmem_k_vec.AsType<ABDataType>()[idx]),
|
||||
temp);
|
||||
v.AsType<float>()[idx] = temp;
|
||||
auto x_val = shmem_x_vec.template AsType<ABDataType>()[idx];
|
||||
auto k_val = shmem_k_vec.template AsType<ABDataType>()[idx];
|
||||
|
||||
v.template AsType<float>()(idx) += x_val * k_val;
|
||||
});
|
||||
}
|
||||
|
||||
@@ -601,8 +600,14 @@ __global__ void kernel_grouped_conv_bwd_data_optimized_v2(const ABDataType* __re
|
||||
(batch_start_id_per_blk + batch_id + batch_id_per_wave) * outgrad_batch_stride +
|
||||
out_y * outgrad_row_stride + out_x * outgrad_col_stride;
|
||||
|
||||
ETypeDstVec_t output_vec;
|
||||
static_for<0, ElementPerOutFP4, 1>{}([&](auto idx) {
|
||||
output_vec.template AsType<EDataType>()(idx) =
|
||||
type_convert<EDataType>(v.template AsType<float>()[idx]);
|
||||
});
|
||||
|
||||
reinterpret_cast<ETypeDstVec_t*>(p_gradIn)[outgrad_offset / ElementPerOutFP4] =
|
||||
type_convert<ETypeDstVec_t>(v);
|
||||
output_vec;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user