mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
update kernel
This commit is contained in:
@@ -436,8 +436,8 @@ __global__ void kernel_grouped_conv_bwd_data_optimized_v2(const ABDataType* __re
|
||||
// WaveNum * GroupPerBlk * TileInH * TileInW
|
||||
constexpr index_t GroupPerBlockInFP4 = GroupPerBlock / ElementPerInFP4;
|
||||
|
||||
__shared__ volatile ABDataType shmem_k[kernelH * kernelW * GroupPerBlock]; // layout : H->W->G
|
||||
__shared__ volatile ABDataType shmem_x[WaveNum * TileInH * TileInW * GroupPerBlock];
|
||||
__shared__ ABDataType shmem_k[kernelH * kernelW * GroupPerBlock]; // layout : H->W->G
|
||||
__shared__ ABDataType shmem_x[WaveNum * TileInH * TileInW * GroupPerBlock];
|
||||
// layout : B->H->W->G; a double buffer will be used to go through BatchPerBlock
|
||||
// for backward data this will be gradOut; for forward it will be x
|
||||
|
||||
@@ -468,7 +468,7 @@ __global__ void kernel_grouped_conv_bwd_data_optimized_v2(const ABDataType* __re
|
||||
const int group_start_id_per_blk = (blockIdx.z % GroupBatchNum) * GroupPerBlock;
|
||||
const int batch_start_id_per_blk = (blockIdx.z / GroupBatchNum) * BatchPerBlock;
|
||||
|
||||
int wave_id = __builtin_amdgcn_readfirstlane(tid / warpSize);
|
||||
// int wave_id = __builtin_amdgcn_readfirstlane(tid / warpSize);
|
||||
|
||||
int tile_mid_x = TileOutW * down_w + up_w - 1 - pad_w;
|
||||
int tile_mid_y = TileOutH * down_h + up_h - 1 - pad_h;
|
||||
@@ -571,15 +571,25 @@ __global__ void kernel_grouped_conv_bwd_data_optimized_v2(const ABDataType* __re
|
||||
#pragma unroll
|
||||
for(int x = 0; x < kernelW / up_w; x++)
|
||||
{
|
||||
// v += shmem_x[rel_in_y + y][rel_in_x + x] *
|
||||
// shmem_k[kernel_y + y * up_y][kernel_x + x * up_x];
|
||||
v += reinterpret_cast<ABDTypeVec_t*>(
|
||||
shmem_x)[(batch_id_per_wave * GroupPerBlockInFP4 * TileInH * TileInW) +
|
||||
(rel_in_y + y) * TileInW * GroupPerBlockInFP4 +
|
||||
(rel_in_x + x) * GroupPerBlockInFP4 + group_out_id] *
|
||||
reinterpret_cast<ABDTypeVec_t*>(
|
||||
shmem_k)[(kernel_y + y * up_h) * kernelW * GroupPerBlock +
|
||||
(kernel_x + x * up_w) * GroupPerBlock + group_out_id];
|
||||
// using ABDTypeVec_t = typename vector_type<ABDataType,
|
||||
// ElementPerInFP4>::type; using EDataTypeVec_t = typename vector_type<float,
|
||||
// ElementPerInFP4>::type;
|
||||
|
||||
ABDTypeVec_t shmem_x_vec = reinterpret_cast<ABDTypeVec_t*>(
|
||||
shmem_x)[(batch_id_per_wave * GroupPerBlockInFP4 * TileInH * TileInW) +
|
||||
(rel_in_y + y) * TileInW * GroupPerBlockInFP4 +
|
||||
(rel_in_x + x) * GroupPerBlockInFP4 + group_out_id];
|
||||
ABDTypeVec_t shmem_k_vec = reinterpret_cast<ABDTypeVec_t*>(
|
||||
shmem_k)[(kernel_y + y * up_h) * kernelW * GroupPerBlock +
|
||||
(kernel_x + x * up_w) * GroupPerBlock + group_out_id];
|
||||
|
||||
static_for<0, ElementPerInFP4, 1>{}([&](auto idx) {
|
||||
float temp = v.AsType<float>()[idx];
|
||||
inner_product(type_convert<float>(shmem_x_vec.AsType<ABDataType>()[idx]),
|
||||
type_convert<float>(shmem_k_vec.AsType<ABDataType>()[idx]),
|
||||
temp);
|
||||
v.AsType<float>()[idx] = temp;
|
||||
});
|
||||
}
|
||||
|
||||
if(out_x < out_width & out_y < out_height)
|
||||
|
||||
Reference in New Issue
Block a user