diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index 9bdd941074..60bfb963c9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -436,8 +436,8 @@ __global__ void kernel_grouped_conv_bwd_data_optimized_v2(const ABDataType* __re // WaveNum * GroupPerBlk * TileInH * TileInW constexpr index_t GroupPerBlockInFP4 = GroupPerBlock / ElementPerInFP4; - __shared__ volatile ABDataType shmem_k[kernelH * kernelW * GroupPerBlock]; // layout : H->W->G - __shared__ volatile ABDataType shmem_x[WaveNum * TileInH * TileInW * GroupPerBlock]; + __shared__ ABDataType shmem_k[kernelH * kernelW * GroupPerBlock]; // layout : H->W->G + __shared__ ABDataType shmem_x[WaveNum * TileInH * TileInW * GroupPerBlock]; // layout : B->H->W->G will use double buffer to go through BatchPerBlock // when backward data, will be gradOut, when forward, will be x @@ -468,7 +468,7 @@ __global__ void kernel_grouped_conv_bwd_data_optimized_v2(const ABDataType* __re const int group_start_id_per_blk = (blockIdx.z % GroupBatchNum) * GroupPerBlock; const int batch_start_id_per_blk = (blockIdx.z / GroupBatchNum) * BatchPerBlock; - int wave_id = __builtin_amdgcn_readfirstlane(tid / warpSize); + // int wave_id = __builtin_amdgcn_readfirstlane(tid / warpSize); int tile_mid_x = TileOutW * down_w + up_w - 1 - pad_w; int tile_mid_y = TileOutH * down_h + up_h - 1 - pad_h; @@ -571,15 +571,25 @@ __global__ void kernel_grouped_conv_bwd_data_optimized_v2(const ABDataType* __re #pragma unroll for(int x = 0; x < kernelW / up_w; x++) { - // v += shmem_x[rel_in_y + y][rel_in_x + x] * - // shmem_k[kernel_y + y * up_y][kernel_x + x * up_x]; - v += reinterpret_cast( - shmem_x)[(batch_id_per_wave * GroupPerBlockInFP4 * TileInH * TileInW) + - (rel_in_y + y) * TileInW * GroupPerBlockInFP4 + - (rel_in_x + x) * GroupPerBlockInFP4 + group_out_id] * - reinterpret_cast( - shmem_k)[(kernel_y + y * up_h) * kernelW * GroupPerBlock + - (kernel_x + x * up_w) * GroupPerBlock + group_out_id]; + // using ABDTypeVec_t = typename vector_type::type; using EDataTypeVec_t = typename vector_type::type; + + ABDTypeVec_t shmem_x_vec = reinterpret_cast( + shmem_x)[(batch_id_per_wave * GroupPerBlockInFP4 * TileInH * TileInW) + + (rel_in_y + y) * TileInW * GroupPerBlockInFP4 + + (rel_in_x + x) * GroupPerBlockInFP4 + group_out_id]; + ABDTypeVec_t shmem_k_vec = reinterpret_cast( + shmem_k)[(kernel_y + y * up_h) * kernelW * GroupPerBlock + + (kernel_x + x * up_w) * GroupPerBlock + group_out_id]; + + static_for<0, ElementPerInFP4, 1>{}([&](auto idx) { + float temp = v.AsType()[idx]; + inner_product(type_convert(shmem_x_vec.AsType()[idx]), + type_convert(shmem_k_vec.AsType()[idx]), + temp); + v.AsType()[idx] = temp; + }); } if(out_x < out_width & out_y < out_height)