Refactor block to C tile map (#235)

* refactor block-to-ctile-map

* gridwise gemm block2ctile generic validity check

* format

* amend split-k gemm block2ctile map refactor

* add test

* format

* amend

* revert to calculating batch index in kernel instead of passing as block_id_z

* move file

* add valid ctile index check to gridwise v2r4
This commit is contained in:
Anthony Chang
2022-05-21 01:40:51 +08:00
committed by GitHub
parent 070619fbf1
commit a054f7d604
33 changed files with 770 additions and 743 deletions

View File

@@ -470,44 +470,6 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
using DGridDesc_M = decltype(MakeDGridDescriptor_M(1));
static constexpr auto MakeBlock2CTileMap(index_t batch_count,
const CGridDesc_M_N& c_grid_desc_m_n,
index_t M01,
index_t N01)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
const auto M00 = M0 / M01;
const auto N00 = N0 / N01;
const auto g_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_insert_transform(batch_count),
make_unmerge_transform(make_tuple(M00, M01)),
make_unmerge_transform(make_tuple(N00, N01))),
make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));
const auto globalblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(batch_count, M00, N00, M01, N01))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto globalblockid_to_m0_n0_block_cluster_adaptor =
chain_tensor_adaptors(g_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
globalblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
return globalblockid_to_m0_n0_block_cluster_adaptor;
}
struct ComputeBasePtrOfStridedBatch
{
ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
@@ -608,8 +570,6 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
LoopSched>;
using Block2CTileMap = decltype(MakeBlock2CTileMap(1, CGridDesc_M_N{}, 1, 1));
// Argument
struct Argument : public BaseArgument
{
@@ -645,15 +605,17 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
type_convert<index_t>(b_grid_desc_bk0_n_bk1_.GetElementSpaceSize()),
type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize()),
type_convert<index_t>(d_grid_desc_m_.GetElementSpaceSize())},
block_2_ctile_map_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op},
dxs_in_element_op_{dxs_in_element_op},
dxs_out_element_op_{dxs_out_element_op}
{
if(GridwiseGemm::CheckValidity(
a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, c_grid_desc_m_n_))
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
@@ -661,8 +623,6 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
d_grid_desc_mblock_mperblock_ =
GridwiseGemm::MakeDGridDescriptor_MBlock_MPerBlock(d_grid_desc_m_);
block_2_ctile_map_ = MakeBlock2CTileMap(BatchCount, c_grid_desc_m_n_, 1, 1);
}
}
@@ -680,7 +640,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
c_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
Block2CTileMap block_2_ctile_map_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
@@ -717,14 +677,16 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
}
#endif
if(!GridwiseGemm::CheckValidity(
arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_))
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
const index_t grid_size =
GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_;
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_;
const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
@@ -747,7 +709,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock,
ComputeBasePtrOfStridedBatch,
remove_reference_t<Block2CTileMap>,
typename GridwiseGemm::DefaultBlock2CTileMap,
true>;
elapsed_time =
@@ -790,7 +752,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock,
ComputeBasePtrOfStridedBatch,
remove_reference_t<Block2CTileMap>,
typename GridwiseGemm::DefaultBlock2CTileMap,
false>;
elapsed_time =
@@ -836,8 +798,10 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
static bool IsSupportedArgument(const Argument& arg)
{
return GridwiseGemm::CheckValidity(
arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_);
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_);
}
// polymorphic

View File

@@ -243,44 +243,6 @@ struct DeviceBatchedGemmXdl
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
static constexpr auto MakeBlock2CTileMap(index_t batch_count,
const CGridDesc_M_N& c_grid_desc_m_n,
index_t M01,
index_t N01)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
const auto M00 = M0 / M01;
const auto N00 = N0 / N01;
const auto g_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_insert_transform(batch_count),
make_unmerge_transform(make_tuple(M00, M01)),
make_unmerge_transform(make_tuple(N00, N01))),
make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));
const auto globalblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(batch_count, M00, N00, M01, N01))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto globalblockid_to_m0_n0_block_cluster_adaptor =
chain_tensor_adaptors(g_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
globalblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
return globalblockid_to_m0_n0_block_cluster_adaptor;
}
struct ComputePtrOffsetOfStridedBatch
{
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
@@ -354,7 +316,7 @@ struct DeviceBatchedGemmXdl
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
using Block2CTileMap = decltype(MakeBlock2CTileMap(1, CGridDesc_M_N{}, 1, 1));
using Block2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap;
// Argument
struct Argument : public BaseArgument
@@ -388,20 +350,21 @@ struct DeviceBatchedGemmXdl
type_convert<index_t>(a_grid_desc_k0_m_k1_.GetElementSpaceSize()),
type_convert<index_t>(b_grid_desc_k0_n_k1_.GetElementSpaceSize()),
type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize())},
block_2_ctile_map_{},
block_2_ctile_map_{
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)},
M01_{M01},
N01_{N01},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
block_2_ctile_map_ = MakeBlock2CTileMap(BatchCount, c_grid_desc_m_n_, M01, N01);
}
}
@@ -446,15 +409,14 @@ struct DeviceBatchedGemmXdl
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_))
arg.block_2_ctile_map_))
{
throw std::runtime_error(
"wrong! GridwiseBatchedGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
}
const index_t grid_size =
GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_;
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_;
const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
@@ -552,8 +514,7 @@ struct DeviceBatchedGemmXdl
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_);
arg.block_2_ctile_map_);
}
// polymorphic

View File

@@ -433,17 +433,16 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
b_grid_desc_kbatch_k0_n_k1_ = descs[I1];
c_grid_desc_m_n_ = descs[I2];
block_2_ctile_map_ =
GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_,
b_grid_desc_kbatch_k0_n_k1_,
c_grid_desc_m_n_,
M01_,
N01_))
block_2_ctile_map_))
{
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_);
block_2_ctile_map_ =
GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
}
}
@@ -502,14 +501,14 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_))
arg.block_2_ctile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight has invalid setting");
}
const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0);
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_, kbatch);
const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0);
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
@@ -659,8 +658,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_);
arg.block_2_ctile_map_);
}
bool IsSupportedArgument(const BaseArgument* p_arg) override

View File

@@ -486,13 +486,16 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
c_grid_desc_m_n_container_.push_back(descs[I2]);
if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2], M01_, N01_))
auto block_2_ctile_map =
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01, N01);
if(GridwiseGemm::CheckValidity(
descs[I0], descs[I1], descs[I2], block_2_ctile_map))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back(
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2]));
block_2_ctile_map_container_.push_back(
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01, N01));
block_2_ctile_map_container_.push_back(block_2_ctile_map);
}
}
}
@@ -572,15 +575,14 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m_n_container_[i],
arg.M01_,
arg.N01_))
arg.block_2_ctile_map_container_[i]))
{
throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting");
}
const index_t grid_size =
GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_container_[i]);
const index_t grid_size = arg.block_2_ctile_map_container_[i].CalculateGridSize(
arg.c_grid_desc_m_n_container_[i]);
const auto K = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) *
arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2);
@@ -703,8 +705,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m_n_container_[i],
arg.M01_,
arg.N01_))
arg.block_2_ctile_map_container_[i]))
{
return false;
}

View File

@@ -540,7 +540,8 @@ struct
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{},
c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{},
c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{},
block_2_ctile_map_{},
block_2_ctile_map_{
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)},
M01_{M01},
N01_{N01},
in_element_op_{in_element_op},
@@ -575,8 +576,10 @@ struct
c0_grid_desc_m_n_ = descs[I3];
c1_grid_desc_m_n_ = descs[I4];
if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ =
GridwiseGemm::
@@ -592,9 +595,6 @@ struct
GridwiseGemm::
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
c1_grid_desc_m_n_);
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
}
}
@@ -689,14 +689,14 @@ struct
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_))
arg.block_2_ctile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r3 has invalid setting");
}
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
@@ -852,8 +852,7 @@ struct
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_);
arg.block_2_ctile_map_);
}
bool IsSupportedArgument(const BaseArgument* p_arg) override

View File

@@ -548,9 +548,13 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
b_grid_desc_k0_n_k1_ = descs[I1];
c_grid_desc_m_n_ = descs[I2];
c0_grid_desc_m_n_ = descs[I3];
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ =
GridwiseGemm::
@@ -561,9 +565,6 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
GridwiseGemm::
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
c0_grid_desc_m_n_);
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
}
}
@@ -649,14 +650,14 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_))
arg.block_2_ctile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r2 has invalid setting");
}
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
@@ -802,8 +803,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_);
arg.block_2_ctile_map_);
}
bool IsSupportedArgument(const BaseArgument* p_arg) override

View File

@@ -520,18 +520,20 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
a_grid_desc_k0_m_k1_ = descs[I0];
b_grid_desc_k0_n_k1_ = descs[I1];
c_grid_desc_m_n_ = descs[I2];
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
c_grid_desc_m_n_ = descs[I2];
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ =
GridwiseGemm::
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
c_grid_desc_m_n_);
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
}
}
@@ -631,14 +633,14 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_))
arg.block_2_ctile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting");
}
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
@@ -774,8 +776,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_);
arg.block_2_ctile_map_);
}
bool IsSupportedArgument(const BaseArgument* p_arg) override

View File

@@ -408,15 +408,16 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
a_grid_desc_k0_m_k1_ = descs[I0];
b_grid_desc_k0_n_k1_ = descs[I1];
c_grid_desc_m_n_ = descs[I2];
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
}
}
@@ -469,14 +470,14 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_))
arg.block_2_ctile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
}
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
@@ -606,8 +607,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_);
arg.block_2_ctile_map_);
}
bool IsSupportedArgument(const BaseArgument* p_arg) override

View File

@@ -259,50 +259,6 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
struct Block2CTileMapMaker
{
Block2CTileMapMaker(index_t num_batches) : num_batches_(num_batches) {}
__host__ __device__ constexpr auto
MakeBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
const auto M00 = M0 / M01;
const auto N00 = N0 / N01;
const auto g_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_insert_transform(num_batches_),
make_unmerge_transform(make_tuple(M00, M01)),
make_unmerge_transform(make_tuple(N00, N01))),
make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));
const auto globalblockid_to_g_m00_m01_n00_n01_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(num_batches_, M00, N00, M01, N01))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto globalblockid_to_m0_n0_block_cluster_adaptor =
chain_tensor_adaptors(g_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
globalblockid_to_g_m00_m01_n00_n01_block_cluster_adaptor);
return globalblockid_to_m0_n0_block_cluster_adaptor;
}
private:
index_t num_batches_;
};
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<
BlockSize,
InDataType,
@@ -345,8 +301,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
using Block2CTileMap =
decltype(Block2CTileMapMaker{1}.MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
using Block2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap;
// Argument
struct Argument : public BaseArgument
@@ -398,18 +353,20 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
b_grid_desc_k0_n_k1_ = descs[I1];
c_grid_desc_m_n_ = descs[I2];
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
a_batch_stride_ = a_grid_desc_k0_m_k1_.GetElementSpaceSize();
b_batch_stride_ = 0;
c_batch_stride_ = c_grid_desc_m_n_.GetElementSpaceSize();
if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
block_2_ctile_map_ = Block2CTileMapMaker{num_subbatches_}.MakeBlock2CTileMap(
c_grid_desc_m_n_, M01, N01);
}
}
@@ -457,16 +414,15 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_))
arg.block_2_ctile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
}
// todo: grid_size times arg.num_subbatches_
const index_t grid_size =
GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_) * arg.num_subbatches_;
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) *
arg.num_subbatches_;
const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
@@ -565,8 +521,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_);
arg.block_2_ctile_map_);
}
// polymorphic

View File

@@ -1073,13 +1073,15 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
c_grid_desc_m_n_container_.push_back(descs[I2]);
if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2], M01_, N01_))
auto block_2_ctile_map =
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_);
if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2], block_2_ctile_map))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back(
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2]));
block_2_ctile_map_container_.push_back(
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_));
block_2_ctile_map_container_.push_back(block_2_ctile_map);
}
}
}
@@ -1129,13 +1131,16 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
c_grid_desc_m_n_container_.push_back(descs[I2]);
if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2], M01_, N01_))
auto block_2_ctile_map =
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_);
if(GridwiseGemm::CheckValidity(
descs[I0], descs[I1], descs[I2], block_2_ctile_map))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back(
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2]));
block_2_ctile_map_container_.push_back(
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_));
block_2_ctile_map_container_.push_back(block_2_ctile_map);
}
}
}
@@ -1194,14 +1199,17 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
c_grid_desc_m_n_container_.push_back(descs[I2]);
if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2], M01_, N01_))
auto block_2_ctile_map =
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_);
if(GridwiseGemm::CheckValidity(
descs[I0], descs[I1], descs[I2], block_2_ctile_map))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back(
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(
descs[I2]));
block_2_ctile_map_container_.push_back(
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_));
block_2_ctile_map_container_.push_back(block_2_ctile_map);
}
}
}
@@ -1286,15 +1294,14 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m_n_container_[i],
arg.M01_,
arg.N01_))
arg.block_2_ctile_map_container_[i]))
{
throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting");
}
const index_t grid_size =
GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_container_[i]);
const index_t grid_size = arg.block_2_ctile_map_container_[i].CalculateGridSize(
arg.c_grid_desc_m_n_container_[i]);
const auto K = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) *
arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2);
@@ -1418,8 +1425,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m_n_container_[i],
arg.M01_,
arg.N01_))
arg.block_2_ctile_map_container_[i]))
{
return false;
}
@@ -1528,10 +1534,10 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
<< ">";
if constexpr(ConvBackwardDataSpecialization ==
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0){
str<< " Filter1x1Stride1Pad0";
}
return str.str();
}

View File

@@ -705,15 +705,16 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
a_grid_desc_k0_m_k1_ = descs[I0];
b_grid_desc_k0_n_k1_ = descs[I1];
c_grid_desc_m_n_ = descs[I2];
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
}
}
@@ -766,14 +767,14 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_))
arg.block_2_ctile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
}
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
@@ -916,8 +917,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_);
arg.block_2_ctile_map_);
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
@@ -1012,7 +1012,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
auto str = std::stringstream();
// clang-format off
str << "DeviceConv" << std::to_string(NumDimSpatial)
str << "DeviceConv" << std::to_string(NumDimSpatial)
<< "DFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
<< "<"
<< BlockSize << ", "

View File

@@ -460,15 +460,17 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
d_grid_desc_m_{DeviceOp::MakeDGridDescriptor_M(MRaw)},
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
d_grid_desc_mblock_mperblock_{},
block_2_ctile_map_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op},
dxs_in_element_op_{dxs_in_element_op},
dxs_out_element_op_{dxs_out_element_op}
{
if(GridwiseGemm::CheckValidity(
a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, c_grid_desc_m_n_))
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
@@ -476,8 +478,6 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
d_grid_desc_mblock_mperblock_ =
GridwiseGemm::MakeDGridDescriptor_MBlock_MPerBlock(d_grid_desc_m_);
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_);
}
}
@@ -528,13 +528,16 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
}
#endif
if(!GridwiseGemm::CheckValidity(
arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_))
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
@@ -640,8 +643,10 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
static bool IsSupportedArgument(const Argument& arg)
{
return GridwiseGemm::CheckValidity(
arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_);
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_);
}
// polymorphic

View File

@@ -257,14 +257,16 @@ struct DeviceGemmXdl
b_grid_desc_k0_n_k1_ = DeviceGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
c_grid_desc_m_n_ = DeviceGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC);
if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
}
}
@@ -310,14 +312,14 @@ struct DeviceGemmXdl
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_))
arg.block_2_ctile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
}
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
@@ -409,8 +411,7 @@ struct DeviceGemmXdl
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_);
arg.block_2_ctile_map_);
}
// polymorphic

View File

@@ -218,8 +218,13 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d
c_grid_desc_m_n_ =
DeviceGemmXdl_C_Shuffle_Bias_2d::MakeCGridDescriptor_M_N(M, N, StrideC);
if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ =
GridwiseGemm::
@@ -230,9 +235,6 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d
GridwiseGemm::
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
c_grid_desc_m_n_);
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
}
}
@@ -285,14 +287,14 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_))
arg.block_2_ctile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 has invalid setting");
}
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
@@ -400,8 +402,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_);
arg.block_2_ctile_map_);
}
// polymorphic

View File

@@ -227,8 +227,13 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation
c_grid_desc_m_n_ = descs[I2];
c0_grid_desc_m_n_ = descs[I3];
if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ =
GridwiseGemm::
@@ -239,9 +244,6 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation
GridwiseGemm::
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
c0_grid_desc_m_n_);
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
}
}
@@ -294,14 +296,14 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_))
arg.block_2_ctile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r5 has invalid setting");
}
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
@@ -409,8 +411,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_);
arg.block_2_ctile_map_);
}
// polymorphic

View File

@@ -256,8 +256,13 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add
c0_grid_desc_m_n_ = descs[I3];
c1_grid_desc_m_n_ = descs[I4];
if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ =
GridwiseGemm::
@@ -273,9 +278,6 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add
GridwiseGemm::
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
c1_grid_desc_m_n_);
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
}
}
@@ -336,14 +338,14 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_))
arg.block_2_ctile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r5 has invalid setting");
}
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
@@ -461,8 +463,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_);
arg.block_2_ctile_map_);
}
// polymorphic

View File

@@ -404,19 +404,19 @@ struct DeviceGemm_Xdl_CShuffle
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_ctile_map_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
if(GridwiseGemm::CheckValidity(
a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, c_grid_desc_m_n_))
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n_);
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_);
}
}
@@ -459,13 +459,16 @@ struct DeviceGemm_Xdl_CShuffle
}
#endif
if(!GridwiseGemm::CheckValidity(
arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_))
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
@@ -555,8 +558,10 @@ struct DeviceGemm_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
return GridwiseGemm::CheckValidity(
arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_);
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_);
}
// polymorphic

View File

@@ -332,17 +332,16 @@ struct DeviceGemmXdlSplitK
K, N, StrideB, k_batch_, KPad);
c_grid_desc_m_n_ = DeviceGemmXdlSplitK::MakeCGridDescriptor_M_N(M, N, StrideC);
block_2_ctile_map_ =
GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_,
b_grid_desc_kbatch_k0_n_k1_,
c_grid_desc_m_n_,
M01_,
N01_))
block_2_ctile_map_))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_grid_desc_m_n_);
block_2_ctile_map_ =
GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
}
}
@@ -395,14 +394,14 @@ struct DeviceGemmXdlSplitK
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_))
arg.block_2_ctile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
}
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_, kbatch);
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
@@ -532,8 +531,7 @@ struct DeviceGemmXdlSplitK
return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_);
arg.block_2_ctile_map_);
}
// polymorphic

View File

@@ -292,8 +292,7 @@ struct DeviceGemmXdlSplitKCShuffle
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}));
using Block2CTileMap =
decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1));
using Block2CTileMap = typename GridwiseGemm::CBlockClusterAdaptor;
// Argument
struct Argument : public BaseArgument
@@ -338,17 +337,16 @@ struct DeviceGemmXdlSplitKCShuffle
K, N, StrideB, k_batch_, KPad);
c_grid_desc_m_n_ = DeviceGemmXdlSplitKCShuffle::MakeCGridDescriptor_M_N(M, N, StrideC);
block_2_ctile_map_ =
GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_,
b_grid_desc_kbatch_k0_n_k1_,
c_grid_desc_m_n_,
M01_,
N01_))
block_2_ctile_map_))
{
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_);
block_2_ctile_map_ =
GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
}
}
@@ -401,14 +399,14 @@ struct DeviceGemmXdlSplitKCShuffle
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_))
arg.block_2_ctile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid setting");
}
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_, kbatch);
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
@@ -541,8 +539,7 @@ struct DeviceGemmXdlSplitKCShuffle
return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_);
arg.block_2_ctile_map_);
}
// polymorphic

View File

@@ -307,6 +307,11 @@ struct DeviceGroupedGemmXdl
struct GroupedGemmBlock2CTileMap
{
using UnderlyingBlock2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap;
static_assert(
std::is_same<decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1)),
typename GridwiseGemm::DefaultBlock2CTileMap>::value,
"Wrong! Should be the same type name");
GroupedGemmBlock2CTileMap()
{
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1);
@@ -329,6 +334,18 @@ struct DeviceGroupedGemmXdl
make_multi_index(idx_top[I0] - BlockStart_));
}
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
const CTileDim& c_tile_dim) const
{
return block_2_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
}
__host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
{
return block_2_ctile_map_.CheckValidity(c_grid_desc_m_n);
}
private:
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
ck::index_t BlockStart_;
@@ -400,22 +417,27 @@ struct DeviceGroupedGemmXdl
const auto c_grid_desc_m_n_ =
DeviceGroupedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC);
const index_t grid_size_grp = GridwiseGemm::CalculateGridSize(c_grid_desc_m_n_);
const index_t grid_size_grp =
typename GroupedGemmBlock2CTileMap::UnderlyingBlock2CTileMap(
c_grid_desc_m_n_, M01, N01)
.CalculateGridSize(c_grid_desc_m_n_);
const index_t BlockStart = grid_size_;
const index_t BlockEnd = grid_size_ + grid_size_grp;
grid_size_ += grid_size_grp;
if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
const auto grouped_gemm_block_2_ctile_map_ =
GroupedGemmBlock2CTileMap(c_grid_desc_m_n_, M01, N01, BlockStart);
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
c_grid_desc_m_n_,
grouped_gemm_block_2_ctile_map_))
{
const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
const auto grouped_gemm_block_2_ctile_map_ =
GroupedGemmBlock2CTileMap(c_grid_desc_m_n_, M01, N01, BlockStart);
gemm_desc_kernel_arg_.push_back(
GemmDescKernelArg{a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
@@ -475,11 +497,11 @@ struct DeviceGroupedGemmXdl
<< gemm_desc_kernel_args[i].c_grid_desc_m_n_.GetLength(I1) << "}"
<< std::endl;
if(!GridwiseGemm::CheckValidity(gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_,
gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_,
gemm_desc_kernel_args[i].c_grid_desc_m_n_,
arg.M01_,
arg.N01_))
if(!GridwiseGemm::CheckValidity(
gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_,
gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_,
gemm_desc_kernel_args[i].c_grid_desc_m_n_,
gemm_desc_kernel_args[i].grouped_gemm_block_2_ctile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");

View File

@@ -0,0 +1,258 @@
#ifndef UTILITY_BLOCK_TO_CTILE_MAP
#define UTILITY_BLOCK_TO_CTILE_MAP
#include "utility/math.hpp"
#include "utility/number.hpp"
#include "tensor_description/tensor_adaptor.hpp"
#include "tensor_description/multi_index_transform_helper.hpp"
namespace ck {
// Blocks of row-vectors
template <index_t MPerBlock,
index_t NPerBlock,
typename CGridDesc_M_N,
bool DeviceCTileIndexCheck = false>
struct BlockToCTileMap_M00_N00_M01_N01
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
__host__ __device__ BlockToCTileMap_M00_N00_M01_N01() = default;
__host__ __device__ BlockToCTileMap_M00_N00_M01_N01(const CGridDesc_M_N& c_grid_desc_m_n,
index_t M01 = 1,
index_t N01 = 1)
: M01_(M01), N01_(N01), underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01, N01))
{
}
__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
{
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
const auto M00 = math::integer_divide_ceil(M0, M01_);
const auto N00 = math::integer_divide_ceil(N0, N01_);
const index_t grid_size = M00 * M01_ * N00 * N01_;
return grid_size;
}
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
return underlying_map_.CalculateBottomIndex(idx_top);
}
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
const CTileDim& c_tile_dim) const
{
if constexpr(DeviceCTileIndexCheck)
return DefaultValidCTileIndex(c_tile_idx, c_tile_dim);
else
return true;
}
__host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
{
if constexpr(DeviceCTileIndexCheck)
return true; // validity check moved to kernel
const index_t M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
const index_t N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
if(M0 % M01_ == 0 && N0 % N01_ == 0)
{
return true;
}
else
{
return false;
}
}
private:
__host__ __device__ static constexpr auto
GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
{
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
const auto M00 = math::integer_divide_ceil(M0, M01);
const auto N00 = math::integer_divide_ceil(N0, N01);
const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_insert_transform(1), // swallow the carry from lower dimensions
make_unmerge_transform(make_tuple(M00, M01)),
make_unmerge_transform(make_tuple(N00, N01))),
make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));
const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(1, M00, N00, M01, N01))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto cblockid_to_m0_n0_block_cluster_adaptor =
chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
return cblockid_to_m0_n0_block_cluster_adaptor;
}
index_t M01_, N01_;
using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1, 1));
UnderlyingMap underlying_map_;
};
// 2D slices of row-vectors in 3D space
template <index_t MPerBlock,
index_t NPerBlock,
typename CGridDesc_M_N,
bool DeviceCTileIndexCheck = false>
struct BlockToCTileMap_KSplit_M00_N00_M01_N01
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
__host__ BlockToCTileMap_KSplit_M00_N00_M01_N01() = default;
__host__ BlockToCTileMap_KSplit_M00_N00_M01_N01(const CGridDesc_M_N& c_grid_desc_m_n,
index_t M01 = 1,
index_t N01 = 1,
index_t KSplit = 1)
: M01_(M01),
N01_(N01),
KSplit_(KSplit),
underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01, N01, KSplit))
{
}
__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
{
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
const auto M00 = math::integer_divide_ceil(M0, M01_);
const auto N00 = math::integer_divide_ceil(N0, N01_);
const index_t grid_size = M00 * M01_ * N00 * N01_ * KSplit_;
return grid_size;
}
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
return underlying_map_.CalculateBottomIndex(idx_top);
}
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
const CTileDim& c_tile_dim) const
{
if constexpr(DeviceCTileIndexCheck)
return DefaultValidCTileIndex(c_tile_idx, c_tile_dim);
else
return true;
}
__host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
{
if constexpr(DeviceCTileIndexCheck)
return true; // validity check moved to kernel
const index_t M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
const index_t N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
if(M0 % M01_ == 0 && N0 % N01_ == 0)
{
return true;
}
else
{
return false;
}
}
private:
__host__ static constexpr auto GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n,
index_t M01,
index_t N01,
index_t KSplit)
{
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
const auto M00 = math::integer_divide_ceil(M0, M01);
const auto N00 = math::integer_divide_ceil(N0, N01);
const auto ksplit_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_pass_through_transform(KSplit),
make_unmerge_transform(make_tuple(M00, M01)),
make_unmerge_transform(make_tuple(N00, N01))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));
const auto c_blockid_to_ksplit_m00_m01_n00_n01_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(KSplit, M00, N00, M01, N01))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto c_blockid_to_ksplit_m0_n0_block_cluster_adaptor =
chain_tensor_adaptors(ksplit_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
c_blockid_to_ksplit_m00_m01_n00_n01_block_cluster_adaptor);
return c_blockid_to_ksplit_m0_n0_block_cluster_adaptor;
}
index_t M01_, N01_, KSplit_;
using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1, 1, 1));
UnderlyingMap underlying_map_;
};
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool DefaultValidCTileIndex(const CTileIdx& c_tile_idx,
const CTileDim& c_tile_dim)
{
bool is_valid = false;
const index_t m_block = c_tile_dim[Number<0>{}];
const index_t n_block = c_tile_dim[Number<1>{}];
if constexpr(CTileIdx::Size() == 2)
{
const index_t m_block_idx = c_tile_idx[Number<0>{}];
const index_t n_block_idx = c_tile_idx[Number<1>{}];
if(0 <= m_block_idx && m_block_idx < m_block && 0 <= n_block_idx && n_block_idx < n_block)
{
is_valid = true;
}
}
else if constexpr(CTileIdx::Size() == 3)
{
const index_t ksplit_idx = c_tile_idx[Number<0>{}];
const index_t m_block_idx = c_tile_idx[Number<1>{}];
const index_t n_block_idx = c_tile_idx[Number<2>{}];
if(0 <= m_block_idx && m_block_idx < m_block && 0 <= n_block_idx && n_block_idx < n_block)
{
is_valid = true;
}
ignore = ksplit_idx;
}
return is_valid;
}
} // namespace ck
#endif // UTILITY_BLOCK_TO_CTILE_MAP

View File

@@ -3,6 +3,7 @@
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "thread_group_tensor_slice_transfer_v6r1.hpp"
@@ -218,10 +219,12 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2CTileMap>
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDesc_M_N& c_grid_desc_m_n)
const CGridDesc_M_N& c_grid_desc_m_n,
const Block2CTileMap& block_2_ctile_map)
{
// static_assert(is_known_at_compile_time<remove_cv_t<decltype(AK1)>>::value &&
// is_known_at_compile_time<remove_cv_t<decltype(BK1)>>::value,
@@ -249,21 +252,15 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
return false;
}
if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
{
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
__host__ __device__ static constexpr index_t
CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
return grid_size;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / KPerBlock;
@@ -309,40 +306,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
// FIXME: remove
constexpr auto M01 = I1;
constexpr auto N01 = I1;
const auto M00 = M0 / M01;
const auto N00 = N0 / N01;
const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(M00, M01)),
make_unmerge_transform(make_tuple(N00, N01))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));
const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))),
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{}));
const auto cblockid_to_m0_n0_block_cluster_adaptor =
chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
return cblockid_to_m0_n0_block_cluster_adaptor;
return BlockToCTileMap_M00_N00_M01_N01<MPerBlock, NPerBlock, CGridDesc_M_N>(
c_grid_desc_m_n);
}
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
@@ -383,6 +348,14 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);

View File

@@ -3,6 +3,7 @@
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "thread_group_tensor_slice_transfer_v6r1.hpp"
@@ -190,10 +191,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2CTileMap>
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDesc_M_N& c_grid_desc_m_n)
const CGridDesc_M_N& c_grid_desc_m_n,
const Block2CTileMap& block_2_ctile_map)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
@@ -217,21 +220,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
return false;
}
if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
{
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
__host__ __device__ static constexpr index_t
CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
return grid_size;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / KPerBlock;
@@ -262,40 +259,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
// FIXME: remove
constexpr auto M01 = I1;
constexpr auto N01 = I1;
const auto M00 = M0 / M01;
const auto N00 = N0 / N01;
const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(M00, M01)),
make_unmerge_transform(make_tuple(N00, N01))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));
const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))),
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{}));
const auto cblockid_to_m0_n0_block_cluster_adaptor =
chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
return cblockid_to_m0_n0_block_cluster_adaptor;
return BlockToCTileMap_M00_N00_M01_N01<MPerBlock, NPerBlock, CGridDesc_M_N>(
c_grid_desc_m_n);
}
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
@@ -329,6 +294,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);

View File

@@ -3,6 +3,7 @@
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
@@ -185,12 +186,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2CTileMap>
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const CGridDesc_M_N& c_grid_desc_m_n,
index_t M01,
index_t N01)
const Block2CTileMap& block_2_ctile_map)
{
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
"wrong! K1 need to be known at compile-time");
@@ -219,31 +220,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
return false;
}
// check M01, N01
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
if(!(M0 % M01 == 0 && N0 % N01 == 0))
if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
{
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
__host__ __device__ static constexpr index_t
CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
return grid_size;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / (K0PerBlock * K1);
@@ -305,36 +290,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
const auto M00 = M0 / M01;
const auto N00 = N0 / N01;
const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(M00, M01)),
make_unmerge_transform(make_tuple(N00, N01))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));
const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))),
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{}));
const auto cblockid_to_m0_n0_block_cluster_adaptor =
chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
return cblockid_to_m0_n0_block_cluster_adaptor;
return BlockToCTileMap_M00_N00_M01_N01<MPerBlock, NPerBlock, CGridDesc_M_N>(
c_grid_desc_m_n, M01, N01);
}
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
@@ -368,6 +325,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0),
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1))))
{
return;
}
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);

View File

@@ -5,6 +5,7 @@
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
@@ -167,12 +168,12 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2CTileMap>
__host__ __device__ static constexpr bool
CheckValidity(const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc,
const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc,
const CMNGridDesc& c_m_n_grid_desc,
index_t M01,
index_t N01)
const Block2CTileMap& block_2_ctile_map)
{
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
"wrong! K1 need to be known at compile-time");
@@ -196,31 +197,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
return false;
// check M01, N01
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
if(!(M0 % M01 == 0 && N0 % N01 == 0))
if(!block_2_ctile_map.CheckValidity(c_m_n_grid_desc))
{
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
__host__ __device__ static constexpr index_t
CalculateGridSize(const CMNGridDesc& c_m_n_grid_desc, index_t KBatch)
{
const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1);
const index_t grid_size = (M / MPerBlock) * (N / NPerBlock) * KBatch;
return grid_size;
}
__host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
{
const bool has_main_k0_block_loop = K0 > K0PerBlock;
@@ -282,37 +267,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
__host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
const CMNGridDesc& c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch)
{
const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1);
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
const auto M00 = M0 / M01;
const auto N00 = N0 / N01;
const auto kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_pass_through_transform(KBatch),
make_unmerge_transform(make_tuple(M00, M01)),
make_unmerge_transform(make_tuple(N00, N01))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));
const auto cblockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(KBatch, M00, N00, M01, N01))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto cblockid_to_kbatch_m0_n0_block_cluster_adaptor =
chain_tensor_adaptors(kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
cblockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor);
return cblockid_to_kbatch_m0_n0_block_cluster_adaptor;
return BlockToCTileMap_KSplit_M00_N00_M01_N01<MPerBlock, NPerBlock, CMNGridDesc>(
c_m_n_grid_desc, M01, N01, KBatch);
}
using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{}));
@@ -344,6 +300,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
const auto block_work_idx =
c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!c_block_cluster_adaptor.ValidCTileIndex(
make_tuple(block_work_idx[I1], block_work_idx[I2]),
make_tuple(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetLength(I0),
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetLength(I1))))
{
return;
}
const index_t k_batch_id = block_work_idx[I0];
// HACK: this force m/n_block_data_idx_on_grid into SGPR

View File

@@ -5,6 +5,7 @@
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "thread_group_tensor_slice_transfer_v6r1.hpp"
@@ -174,12 +175,12 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2CTileMap>
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
const CMNGridDesc& c_m_n_grid_desc,
index_t M01,
index_t N01)
const Block2CTileMap& block_2_ctile_map)
{
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
"wrong! K1 need to be known at compile-time");
@@ -203,31 +204,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
return false;
// check M01, N01
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
if(!(M0 % M01 == 0 && N0 % N01 == 0))
if(!block_2_ctile_map.CheckValidity(c_m_n_grid_desc))
{
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
__host__ __device__ static constexpr index_t
CalculateGridSize(const CMNGridDesc& c_m_n_grid_desc, index_t KBatch)
{
const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1);
const index_t grid_size = (M / MPerBlock) * (N / NPerBlock) * KBatch;
return grid_size;
}
__host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
{
const bool has_main_k0_block_loop = K0 > K0PerBlock;
@@ -256,37 +241,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
__host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
const CMNGridDesc& c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch)
{
const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1);
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
const auto M00 = M0 / M01;
const auto N00 = N0 / N01;
const auto kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_pass_through_transform(KBatch),
make_unmerge_transform(make_tuple(M00, M01)),
make_unmerge_transform(make_tuple(N00, N01))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));
const auto c_blockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(KBatch, M00, N00, M01, N01))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto c_blockid_to_kbatch_m0_n0_block_cluster_adaptor =
chain_tensor_adaptors(kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
c_blockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor);
return c_blockid_to_kbatch_m0_n0_block_cluster_adaptor;
return BlockToCTileMap_KSplit_M00_N00_M01_N01<MPerBlock, NPerBlock, CMNGridDesc>(
c_m_n_grid_desc, M01, N01, KBatch);
}
__host__ __device__ static constexpr auto
@@ -333,6 +289,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
const auto block_work_idx =
c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!c_block_cluster_adaptor.ValidCTileIndex(
make_tuple(block_work_idx[I1], block_work_idx[I2]),
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
const index_t k_batch_id = block_work_idx[I0];
// HACK: this force m/n_block_data_idx_on_grid into SGPR

View File

@@ -3,6 +3,7 @@
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "thread_group_tensor_slice_transfer_v6r1.hpp"
@@ -223,12 +224,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2CTileMap>
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDesc_M_N& c_grid_desc_m_n,
index_t M01,
index_t N01)
const Block2CTileMap& block_2_ctile_map)
{
// static_assert(is_known_at_compile_time<remove_cv_t<decltype(AK1)>>::value &&
// is_known_at_compile_time<remove_cv_t<decltype(BK1)>>::value,
@@ -256,31 +257,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
return false;
}
// check M01, N01
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
if(!(M0 % M01 == 0 && N0 % N01 == 0))
if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
{
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
__host__ __device__ static constexpr index_t
CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
return grid_size;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / KPerBlock;
@@ -318,36 +303,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
const auto M00 = M0 / M01;
const auto N00 = N0 / N01;
const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(M00, M01)),
make_unmerge_transform(make_tuple(N00, N01))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));
const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))),
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{}));
const auto cblockid_to_m0_n0_block_cluster_adaptor =
chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
return cblockid_to_m0_n0_block_cluster_adaptor;
return BlockToCTileMap_M00_N00_M01_N01<MPerBlock, NPerBlock, CGridDesc_M_N>(
c_grid_desc_m_n, M01, N01);
}
using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
remove_cvref_t<decltype(
@@ -385,6 +342,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
.GetLength(I0),
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
.GetLength(I3))))
{
return;
}
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);

View File

@@ -5,6 +5,7 @@
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "thread_group_tensor_slice_transfer_v6r2.hpp"
@@ -230,12 +231,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2CTileMap>
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const CGridDesc_M_N& c_grid_desc_m_n,
index_t M01,
index_t N01)
const Block2CTileMap& block_2_ctile_map)
{
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
"wrong! K1 need to be known at compile-time");
@@ -264,31 +265,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
return false;
}
// check M01, N01
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
if(!(M0 % M01 == 0 && N0 % N01 == 0))
if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
{
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
__host__ __device__ static constexpr index_t
CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
return grid_size;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / (K0PerBlock * K1);
@@ -327,37 +312,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
const auto M00 = M0 / M01;
const auto N00 = N0 / N01;
const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(M00, M01)),
make_unmerge_transform(make_tuple(N00, N01))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));
const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))),
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{}));
const auto cblockid_to_m0_n0_block_cluster_adaptor =
chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
return cblockid_to_m0_n0_block_cluster_adaptor;
return BlockToCTileMap_M00_N00_M01_N01<MPerBlock, NPerBlock, CGridDesc_M_N>(
c_grid_desc_m_n, M01, N01);
}
using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
@@ -408,6 +366,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
.GetLength(I0),
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
.GetLength(I3))))
{
return;
}
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);

View File

@@ -3,6 +3,7 @@
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "thread_group_tensor_slice_transfer_v6r3.hpp"
@@ -237,12 +238,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2CTileMap>
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const CGridDesc_M_N& c_grid_desc_m_n,
index_t M01,
index_t N01)
const Block2CTileMap& block_2_ctile_map)
{
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
"wrong! K1 need to be known at compile-time");
@@ -271,31 +272,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
return false;
}
// check M01, N01
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
if(!(M0 % M01 == 0 && N0 % N01 == 0))
if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
{
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
__host__ __device__ static constexpr index_t
CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
return grid_size;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / (K0PerBlock * K1);
@@ -334,36 +319,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
const auto M00 = M0 / M01;
const auto N00 = N0 / N01;
const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(M00, M01)),
make_unmerge_transform(make_tuple(N00, N01))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));
const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))),
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{}));
const auto cblockid_to_m0_n0_block_cluster_adaptor =
chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
return cblockid_to_m0_n0_block_cluster_adaptor;
return BlockToCTileMap_M00_N00_M01_N01<MPerBlock, NPerBlock, CGridDesc_M_N>(
c_grid_desc_m_n, M01, N01);
}
using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
remove_cvref_t<decltype(
@@ -427,6 +384,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
.GetLength(I0),
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
.GetLength(I3))))
{
return;
}
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);