mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 19:57:40 +00:00
tmp save
This commit is contained in:
@@ -1046,6 +1046,10 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
|
||||
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
|
||||
|
||||
if(threadIdx.x == 0) {
|
||||
printf("Repeat: (M N K): (%d, %d, %d)\n", m0.value, n0.value, k0.value);
|
||||
}
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
@@ -1055,11 +1059,14 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
|
||||
if(threadIdx.x == 0) {
|
||||
printf("a: %f b: %f\n",
|
||||
printf("a: %f b: %f a_off: %d b_off: %d\n",
|
||||
static_cast<float>(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}]),
|
||||
static_cast<float>(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}]));
|
||||
make_tuple(n0, I0, k0, ik))>{}]),
|
||||
a_thread_desc_.CalculateOffset(make_tuple(m0, I0, k0, ik)),
|
||||
b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
@@ -458,10 +458,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemmMultiD_xdl_cshuffle_v3<
|
||||
ALayout,
|
||||
BLayout,
|
||||
tensor_layout::gemm::RowMajor,
|
||||
tensor_layout::gemm::RowMajor,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
tensor_layout::gemm::RowMajor,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
|
||||
@@ -651,6 +651,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
MBlock{CalculateMBlock(M_)},
|
||||
NBlock{CalculateNBlock(N_)}
|
||||
{
|
||||
Print();
|
||||
}
|
||||
|
||||
__host__ void Print() const
|
||||
@@ -932,6 +933,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
constexpr index_t BBlockLdsExtraN = BBlockLdsExtraNCustom;
|
||||
#endif
|
||||
|
||||
static_assert(BBlockTransferSrcVectorDim == 1, "should be 1 now!");
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
if constexpr(DirectLoad && BBlockTransferSrcVectorDim == 2)
|
||||
{
|
||||
@@ -1692,6 +1695,10 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
|
||||
if(threadIdx.x == 0) {
|
||||
printf("a size aligned: %ld, a size: %ld b size: %ld\n", a_block_space_size_aligned.value, a_block_desc_ak0_m_ak1.GetElementSpaceSize().value, b_block_desc_bk0_n_bk1.GetElementSpaceSize().value);
|
||||
}
|
||||
|
||||
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<LDSTypeB*>(p_shared) +
|
||||
a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB),
|
||||
|
||||
@@ -69,7 +69,7 @@ using device_grouped_conv_bwd_data_xdl_v3_f16_instances = std::tuple<
|
||||
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1,1,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, 1, 1, false>,
|
||||
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, S<2,2,2>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, 1, 1, true>,
|
||||
//DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 8>, S<1,1,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, 1, 1, false>
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<1,1,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, 1, 1, true>
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<2,2,2>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, 1, 1, true>
|
||||
//DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, S<1,1,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, 1, 1, true>,
|
||||
//DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, S<1,1,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, 1, 1, true>
|
||||
// DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>>,
|
||||
|
||||
@@ -121,6 +121,38 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
|
||||
out.GenerateTensorValue(GeneratorTensor_3<OutDataType>{0.0, 1.0});
|
||||
wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
|
||||
break;
|
||||
case 3:
|
||||
out.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
|
||||
wei.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{2});
|
||||
break;
|
||||
case 4:
|
||||
out.GenerateTensorValue(GeneratorTensor_1<OutDataType>{2});
|
||||
wei.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
|
||||
break;
|
||||
case 5:
|
||||
out.GenerateTensorValue(GeneratorTensor_3<OutDataType>{0.0, 1.0});
|
||||
wei.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
|
||||
break;
|
||||
case 6:
|
||||
out.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
|
||||
wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{0.0, 1.0});
|
||||
break;
|
||||
case 7:
|
||||
out.GenerateTensorValue(GeneratorTensor_3<OutDataType>{0.0, 1.0});
|
||||
wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{0.0, 1.0});
|
||||
break;
|
||||
case 8:
|
||||
out.GenerateTensorValue(GeneratorTensor_Sequential<OutDataType, 2>{});
|
||||
wei.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
|
||||
break;
|
||||
case 9:
|
||||
out.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
|
||||
wei.GenerateTensorValue(GeneratorTensor_Sequential<WeiDataType, 1>{});
|
||||
break;
|
||||
case 10:
|
||||
out.GenerateTensorValue(GeneratorTensor_Sequential<OutDataType, 2>{});
|
||||
wei.GenerateTensorValue(GeneratorTensor_Sequential<WeiDataType, 1>{});
|
||||
break;
|
||||
default:
|
||||
out.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
|
||||
wei.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
|
||||
|
||||
Reference in New Issue
Block a user