mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
use old ctile to avoid conv2d fwd bias relu add compute error (#271)
[ROCm/composable_kernel commit: 1c5d06f270]
This commit is contained in:
@@ -224,10 +224,10 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
input.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
|
||||
weights.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
|
||||
bias.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
|
||||
residual.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
|
||||
input.GenerateTensorValue(GeneratorTensor_2<InDataType>{-2, 2});
|
||||
weights.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-2, 2});
|
||||
bias.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-2, 2});
|
||||
residual.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-2, 2});
|
||||
break;
|
||||
default:
|
||||
input.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
|
||||
|
||||
@@ -460,6 +460,8 @@ struct
|
||||
using C0GridDesc_M_N = remove_cvref_t<decltype(GridDescs{}[I3])>;
|
||||
using C1GridDesc_M_N = remove_cvref_t<decltype(GridDescs{}[I4])>;
|
||||
|
||||
using Block2CTileMap = BlockToCTileMap_M00_N0_M01<MPerBlock, NPerBlock, CGridDesc_M_N>;
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3<
|
||||
BlockSize,
|
||||
@@ -522,8 +524,6 @@ struct
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
ck::index_t M01,
|
||||
ck::index_t N01,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op)
|
||||
@@ -540,10 +540,7 @@ struct
|
||||
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{},
|
||||
c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{},
|
||||
c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{},
|
||||
block_2_ctile_map_{
|
||||
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)},
|
||||
M01_{M01},
|
||||
N01_{N01},
|
||||
block_2_ctile_map_{},
|
||||
in_element_op_{in_element_op},
|
||||
wei_element_op_{wei_element_op},
|
||||
out_element_op_{out_element_op},
|
||||
@@ -576,6 +573,8 @@ struct
|
||||
c0_grid_desc_m_n_ = descs[I3];
|
||||
c1_grid_desc_m_n_ = descs[I4];
|
||||
|
||||
block_2_ctile_map_ = Block2CTileMap{c_grid_desc_m_n_};
|
||||
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
|
||||
b_grid_desc_k0_n_k1_,
|
||||
c_grid_desc_m_n_,
|
||||
@@ -618,9 +617,7 @@ struct
|
||||
typename GridwiseGemm::
|
||||
C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
|
||||
c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_;
|
||||
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
|
||||
index_t M01_;
|
||||
index_t N01_;
|
||||
Block2CTileMap block_2_ctile_map_;
|
||||
InElementwiseOperation in_element_op_;
|
||||
WeiElementwiseOperation wei_element_op_;
|
||||
OutElementwiseOperation out_element_op_;
|
||||
@@ -723,7 +720,7 @@ struct
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation,
|
||||
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
|
||||
Block2CTileMap,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
@@ -767,7 +764,7 @@ struct
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation,
|
||||
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
|
||||
Block2CTileMap,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
@@ -894,8 +891,6 @@ struct
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
1,
|
||||
1,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op};
|
||||
@@ -938,8 +933,6 @@ struct
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
1,
|
||||
1,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
|
||||
@@ -340,7 +340,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
|
||||
using DefaultBlock2CTileMap =
|
||||
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
|
||||
|
||||
template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
|
||||
template <bool HasMainKBlockLoop, typename Block2CTileMap>
|
||||
__device__ static void
|
||||
Run(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
|
||||
Reference in New Issue
Block a user