mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
fix v1r3 output reorder bug
This commit is contained in:
@@ -359,19 +359,6 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
|
||||
#if 0
|
||||
constexpr auto out_10d_global_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
|
||||
K1,
|
||||
K2,
|
||||
Ho,
|
||||
Wo / (W1 * W2),
|
||||
W1,
|
||||
W2,
|
||||
N / f_dummy(N1 * N2),
|
||||
N1,
|
||||
N2>{});
|
||||
#else
|
||||
constexpr auto out_10d_global_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<N / f_dummy(N1 * N2),
|
||||
N1,
|
||||
@@ -383,7 +370,6 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
|
||||
Wo / (W1 * W2),
|
||||
W1,
|
||||
W2>{});
|
||||
#endif
|
||||
|
||||
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
|
||||
@@ -401,20 +387,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
|
||||
}
|
||||
#endif
|
||||
|
||||
#if 0
|
||||
threadwise_nd_tensor_copy(out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_k_h_w_n_global_desc.Get1dIndex(
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite_N>{});
|
||||
#else
|
||||
constexpr auto map_out_global2thread = Sequence<7, 8, 9, 0, 1, 2, 6, 3, 4, 5>{};
|
||||
constexpr auto map_out_global2thread = Sequence<7, 8, 9, 0, 1, 2, 3, 4, 5, 6>{};
|
||||
|
||||
threadwise_nd_tensor_copy_reorder_given_dst2src_v2(
|
||||
out_10d_thread_desc,
|
||||
@@ -428,8 +401,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
|
||||
wo_block_data_begin + wo_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
map_out_global2thread);
|
||||
// Number<OutThreadCopyDataPerWrite_W>{});
|
||||
#endif
|
||||
// Number<OutThreadCopyDataPerWrite_W>{});
|
||||
})
|
||||
.else_([&](auto f_dummy) {
|
||||
static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
|
||||
@@ -446,19 +418,6 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
|
||||
#if 0
|
||||
constexpr auto out_10d_global_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
|
||||
K1,
|
||||
K2,
|
||||
Ho,
|
||||
Wo / (W1 * W2 * W3),
|
||||
W1,
|
||||
W2,
|
||||
W3,
|
||||
N / N1,
|
||||
N1>{});
|
||||
#else
|
||||
constexpr auto out_10d_global_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<N / N1,
|
||||
N1,
|
||||
@@ -470,7 +429,6 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
|
||||
W1,
|
||||
W2,
|
||||
W3>{});
|
||||
#endif
|
||||
|
||||
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});
|
||||
@@ -486,26 +444,9 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
|
||||
"out_k_h_w_n_global_desc");
|
||||
print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
|
||||
|
||||
for(index_t i = 0; i < 64; ++i)
|
||||
{
|
||||
printf("out %f, ", p_out_thread[i]);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#if 0
|
||||
threadwise_nd_tensor_copy(out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_k_h_w_n_global_desc.Get1dIndex(
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite_N>{});
|
||||
#else
|
||||
constexpr auto map_out_global2thread = Sequence<8, 9, 0, 1, 2, 3, 4, 5, 6, 7>{};
|
||||
|
||||
threadwise_nd_tensor_copy_reorder_given_dst2src_v2(
|
||||
@@ -520,8 +461,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
|
||||
wo_block_data_begin + wo_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
map_out_global2thread);
|
||||
// Number<OutThreadCopyDataPerWrite_W>{});
|
||||
#endif
|
||||
// Number<OutThreadCopyDataPerWrite_W>{});
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user