mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
use old ctile to avoid conv2d fwd bias relu add compute error (#271)
This commit is contained in:
@@ -460,6 +460,8 @@ struct
|
||||
using C0GridDesc_M_N = remove_cvref_t<decltype(GridDescs{}[I3])>;
|
||||
using C1GridDesc_M_N = remove_cvref_t<decltype(GridDescs{}[I4])>;
|
||||
|
||||
using Block2CTileMap = BlockToCTileMap_M00_N0_M01<MPerBlock, NPerBlock, CGridDesc_M_N>;
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3<
|
||||
BlockSize,
|
||||
@@ -522,8 +524,6 @@ struct
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
ck::index_t M01,
|
||||
ck::index_t N01,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op)
|
||||
@@ -540,10 +540,7 @@ struct
|
||||
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{},
|
||||
c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{},
|
||||
c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{},
|
||||
block_2_ctile_map_{
|
||||
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)},
|
||||
M01_{M01},
|
||||
N01_{N01},
|
||||
block_2_ctile_map_{},
|
||||
in_element_op_{in_element_op},
|
||||
wei_element_op_{wei_element_op},
|
||||
out_element_op_{out_element_op},
|
||||
@@ -576,6 +573,8 @@ struct
|
||||
c0_grid_desc_m_n_ = descs[I3];
|
||||
c1_grid_desc_m_n_ = descs[I4];
|
||||
|
||||
block_2_ctile_map_ = Block2CTileMap{c_grid_desc_m_n_};
|
||||
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
|
||||
b_grid_desc_k0_n_k1_,
|
||||
c_grid_desc_m_n_,
|
||||
@@ -618,9 +617,7 @@ struct
|
||||
typename GridwiseGemm::
|
||||
C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
|
||||
c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_;
|
||||
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
|
||||
index_t M01_;
|
||||
index_t N01_;
|
||||
Block2CTileMap block_2_ctile_map_;
|
||||
InElementwiseOperation in_element_op_;
|
||||
WeiElementwiseOperation wei_element_op_;
|
||||
OutElementwiseOperation out_element_op_;
|
||||
@@ -723,7 +720,7 @@ struct
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation,
|
||||
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
|
||||
Block2CTileMap,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
@@ -767,7 +764,7 @@ struct
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation,
|
||||
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
|
||||
Block2CTileMap,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
@@ -894,8 +891,6 @@ struct
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
1,
|
||||
1,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op};
|
||||
@@ -938,8 +933,6 @@ struct
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
1,
|
||||
1,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
|
||||
Reference in New Issue
Block a user