update kernel

This commit is contained in:
joyeamd
2025-06-10 08:11:30 +08:00
parent d70acd683b
commit 4e469f5572

View File

@@ -436,8 +436,8 @@ __global__ void kernel_grouped_conv_bwd_data_optimized_v2(const ABDataType* __re
// WaveNum * GroupPerBlk * TileInH * TileInW
constexpr index_t GroupPerBlockInFP4 = GroupPerBlock / ElementPerInFP4;
__shared__ volatile ABDataType shmem_k[kernelH * kernelW * GroupPerBlock]; // layout : H->W->G
__shared__ volatile ABDataType shmem_x[WaveNum * TileInH * TileInW * GroupPerBlock];
__shared__ ABDataType shmem_k[kernelH * kernelW * GroupPerBlock]; // layout : H->W->G
__shared__ ABDataType shmem_x[WaveNum * TileInH * TileInW * GroupPerBlock];
// layout : B->H->W->G will use double buffer to go through BatchPerBlock
// when backward data, will be gradOut, when forward, will be x
@@ -468,7 +468,7 @@ __global__ void kernel_grouped_conv_bwd_data_optimized_v2(const ABDataType* __re
const int group_start_id_per_blk = (blockIdx.z % GroupBatchNum) * GroupPerBlock;
const int batch_start_id_per_blk = (blockIdx.z / GroupBatchNum) * BatchPerBlock;
int wave_id = __builtin_amdgcn_readfirstlane(tid / warpSize);
// int wave_id = __builtin_amdgcn_readfirstlane(tid / warpSize);
int tile_mid_x = TileOutW * down_w + up_w - 1 - pad_w;
int tile_mid_y = TileOutH * down_h + up_h - 1 - pad_h;
@@ -571,15 +571,25 @@ __global__ void kernel_grouped_conv_bwd_data_optimized_v2(const ABDataType* __re
#pragma unroll
for(int x = 0; x < kernelW / up_w; x++)
{
// v += shmem_x[rel_in_y + y][rel_in_x + x] *
// shmem_k[kernel_y + y * up_y][kernel_x + x * up_x];
v += reinterpret_cast<ABDTypeVec_t*>(
shmem_x)[(batch_id_per_wave * GroupPerBlockInFP4 * TileInH * TileInW) +
(rel_in_y + y) * TileInW * GroupPerBlockInFP4 +
(rel_in_x + x) * GroupPerBlockInFP4 + group_out_id] *
reinterpret_cast<ABDTypeVec_t*>(
shmem_k)[(kernel_y + y * up_h) * kernelW * GroupPerBlock +
(kernel_x + x * up_w) * GroupPerBlock + group_out_id];
// using ABDTypeVec_t = typename vector_type<ABDataType,
// ElementPerInFP4>::type; using EDataTypeVec_t = typename vector_type<float,
// ElementPerInFP4>::type;
ABDTypeVec_t shmem_x_vec = reinterpret_cast<ABDTypeVec_t*>(
shmem_x)[(batch_id_per_wave * GroupPerBlockInFP4 * TileInH * TileInW) +
(rel_in_y + y) * TileInW * GroupPerBlockInFP4 +
(rel_in_x + x) * GroupPerBlockInFP4 + group_out_id];
ABDTypeVec_t shmem_k_vec = reinterpret_cast<ABDTypeVec_t*>(
shmem_k)[(kernel_y + y * up_h) * kernelW * GroupPerBlock +
(kernel_x + x * up_w) * GroupPerBlock + group_out_id];
static_for<0, ElementPerInFP4, 1>{}([&](auto idx) {
float temp = v.AsType<float>()[idx];
inner_product(type_convert<float>(shmem_x_vec.AsType<ABDataType>()[idx]),
type_convert<float>(shmem_k_vec.AsType<ABDataType>()[idx]),
temp);
v.AsType<float>()[idx] = temp;
});
}
if(out_x < out_width & out_y < out_height)