Fix compilation errors

This commit is contained in:
joyeamd
2025-06-09 20:44:03 +08:00
parent 9872d2e159
commit 4c4dd342ed

View File

@@ -418,7 +418,7 @@ __global__ void kernel_grouped_conv_bwd_data_optimized_v2(const ABDataType* __re
const index_t in_width,
const index_t in_height,
const index_t group_num,
const index_t whole_batch_num)
const index_t batch_num)
{
constexpr index_t ElementPerInFP4 = 16 / sizeof(ABDataType);
constexpr index_t ElementPerOutFP4 = 16 / sizeof(EDataType);
@@ -580,7 +580,7 @@ __global__ void kernel_grouped_conv_bwd_data_optimized_v2(const ABDataType* __re
(kernel_x + x * up_w) * GroupPerBlock + group_out_id];
}
if(out_x < p.out_w & out_y < p.out_h)
if(out_x < out_width & out_y < out_height)
{
// global outgrad layout : NHWGK; shared outgrad layout : H->W->G
int outgrad_offset =
@@ -589,7 +589,7 @@ __global__ void kernel_grouped_conv_bwd_data_optimized_v2(const ABDataType* __re
(batch_start_id_per_blk + batch_id + batch_id_per_wave) * outgrad_batch_stride +
out_y * outgrad_row_stride + out_x * outgrad_col_stride;
reinterpret_cast<ETypeDstVec_t*>(p_gradOut)[outgrad_offset / ElementPerFP4] =
reinterpret_cast<ETypeDstVec_t*>(p_gradIn)[outgrad_offset / ElementPerInFP4] =
type_convert<ETypeDstVec_t>(v);
}
}
@@ -1379,7 +1379,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
auto launch_kernel = [&]() {
// constexpr bool has_main_loop = has_main_k_block_loop.value;
constexpr index_t GroupPerBlock = 64;
constexpr index_t GroupPerBlock = 32;
constexpr index_t BatchPerBlock = 8;
constexpr index_t BlockDim = 512;