This commit is contained in:
subhajitdchow
2026-01-29 14:59:18 +00:00
parent a9fcb27ded
commit 165805cee7
5 changed files with 52 additions and 6 deletions

View File

@@ -1046,6 +1046,10 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
if(threadIdx.x == 0) {
printf("Repeat: (M N K): (%d, %d, %d)\n", m0.value, n0.value, k0.value);
}
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
@@ -1055,11 +1059,14 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
make_tuple(n0, I0, k0, ik))>{}];
if(threadIdx.x == 0) {
printf("a: %f b: %f\n",
printf("a: %f b: %f a_off: %d b_off: %d\n",
static_cast<float>(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}]),
static_cast<float>(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}]));
make_tuple(n0, I0, k0, ik))>{}]),
a_thread_desc_.CalculateOffset(make_tuple(m0, I0, k0, ik)),
b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))
);
}
});

View File

@@ -458,10 +458,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultiD_xdl_cshuffle_v3<
ALayout,
BLayout,
tensor_layout::gemm::RowMajor,
tensor_layout::gemm::RowMajor,
DsLayout,
ELayout,
tensor_layout::gemm::RowMajor,
ADataType,
BDataType,
AccDataType,

View File

@@ -651,6 +651,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
MBlock{CalculateMBlock(M_)},
NBlock{CalculateNBlock(N_)}
{
Print();
}
__host__ void Print() const
@@ -932,6 +933,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
constexpr index_t BBlockLdsExtraN = BBlockLdsExtraNCustom;
#endif
static_assert(BBlockTransferSrcVectorDim == 1, "should be 1 now!");
// B matrix in LDS memory, dst of blockwise copy
if constexpr(DirectLoad && BBlockTransferSrcVectorDim == 2)
{
@@ -1692,6 +1695,10 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
if(threadIdx.x == 0) {
printf("a size aligned: %ld, a size: %ld b size: %ld\n", a_block_space_size_aligned.value, a_block_desc_ak0_m_ak1.GetElementSpaceSize().value, b_block_desc_bk0_n_bk1.GetElementSpaceSize().value);
}
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<LDSTypeB*>(p_shared) +
a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB),

View File

@@ -69,7 +69,7 @@ using device_grouped_conv_bwd_data_xdl_v3_f16_instances = std::tuple<
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1,1,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, 1, 1, false>,
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, S<2,2,2>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, 1, 1, true>,
//DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 8>, S<1,1,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, 1, 1, false>
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<1,1,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, 1, 1, true>
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<2,2,2>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, 1, 1, true>
//DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, S<1,1,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, 1, 1, true>,
//DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, S<1,1,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, 1, 1, true>
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>>,

View File

@@ -121,6 +121,38 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
out.GenerateTensorValue(GeneratorTensor_3<OutDataType>{0.0, 1.0});
wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
break;
case 3:
out.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
wei.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{2});
break;
case 4:
out.GenerateTensorValue(GeneratorTensor_1<OutDataType>{2});
wei.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
break;
case 5:
out.GenerateTensorValue(GeneratorTensor_3<OutDataType>{0.0, 1.0});
wei.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
break;
case 6:
out.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{0.0, 1.0});
break;
case 7:
out.GenerateTensorValue(GeneratorTensor_3<OutDataType>{0.0, 1.0});
wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{0.0, 1.0});
break;
case 8:
out.GenerateTensorValue(GeneratorTensor_Sequential<OutDataType, 2>{});
wei.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
break;
case 9:
out.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
wei.GenerateTensorValue(GeneratorTensor_Sequential<WeiDataType, 1>{});
break;
case 10:
out.GenerateTensorValue(GeneratorTensor_Sequential<OutDataType, 2>{});
wei.GenerateTensorValue(GeneratorTensor_Sequential<WeiDataType, 1>{});
break;
default:
out.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
wei.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});