Grouped convolution forward with clamp (#2334)

* Grouped convolution forward with clamp

* Optimize clamp

* unary fixes

* test gk bias

* Revert "test gk bias"

This reverts commit 8e42e29d7b.

* Revert "Revert "test gk bias""

This reverts commit e73c0550ce.

* workaround comment
This commit is contained in:
Bartłomiej Kocot
2025-06-16 15:36:53 +02:00
committed by GitHub
parent d996bc78be
commit f6c2ff9dce
41 changed files with 2103 additions and 106 deletions

View File

@@ -71,11 +71,13 @@ template <typename ADataType,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched,
PipelineVersion PipelineVer = PipelineVersion::v1,
typename BComputeDataType_ = AComputeDataType_>
PipelineVersion PipelineVer = PipelineVersion::v1,
typename BComputeDataType_ = AComputeDataType_,
bool DoElementwiseBeforeCShuffle = false>
struct GridwiseGemmMultipleD_xdl_cshuffle
{
static constexpr index_t NumDTensor = DsDataType::Size();
static_assert(!DoElementwiseBeforeCShuffle || NumDTensor == 0);
using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization;
@@ -796,37 +798,60 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
tensor_operation::element_wise::PassThrough pass_through{};
const auto& vpgr_to_lds_element_op = [&] {
if constexpr(DoElementwiseBeforeCShuffle)
{
return cde_element_op;
}
else
{
return pass_through;
}
};
const auto& lds_to_global_element_op = [&] {
if constexpr(!DoElementwiseBeforeCShuffle)
{
return cde_element_op;
}
else
{
return pass_through;
}
};
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
conditional_t<DoElementwiseBeforeCShuffle,
CDEElementwiseOperation,
tensor_operation::element_wise::PassThrough>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
vpgr_to_lds_element_op()};
// tuple of reference to C/Ds tensor descriptors
const auto c_ds_desc_refs = concat_tuple_of_reference(
@@ -860,7 +885,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
Tuple<EDataType>,
decltype(c_ds_desc_refs),
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
CDEElementwiseOperation,
conditional_t<!DoElementwiseBeforeCShuffle,
CDEElementwiseOperation,
tensor_operation::element_wise::PassThrough>,
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
// support arbitray type
Sequence<1,
@@ -881,7 +908,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
idx_c_ds_block_begin,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
cde_element_op};
lds_to_global_element_op()};
// space filling curve for threadwise C in VGPR before shuffle
constexpr auto sfc_c_vgpr =

View File

@@ -186,6 +186,8 @@ __global__ void
/// in global memory. Currently not supported!
/// @tparam PermuteB Whether the B input tensor has gridwise-gemm friendly data layout
/// in global memory (pre-shuffled).
/// @tparam DoElementwiseBeforeCShuffle Whether the cde_elementwise should be performed before or
/// after elementwise op.
template <typename ALayout,
typename BLayout,
typename CLayout,
@@ -233,7 +235,8 @@ template <typename ALayout,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA,
bool PermuteA = false,
bool PermuteB = false>
bool PermuteB = false,
bool DoElementwiseBeforeCShuffle = false>
struct GridwiseGemm_xdl_cshuffle_v3
{
static constexpr auto I0 = Number<0>{};
@@ -636,7 +639,10 @@ struct GridwiseGemm_xdl_cshuffle_v3
index_t StrideA_,
index_t StrideB_,
index_t StrideC_,
index_t KBatch_)
index_t KBatch_,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
: M{M_},
N{N_},
K{K_},
@@ -651,7 +657,10 @@ struct GridwiseGemm_xdl_cshuffle_v3
AK0{CalculateAK0Padded(K_, KBatch_)},
BK0{CalculateBK0Padded(K_, KBatch_)},
MBlock{CalculateMBlock(M_)},
NBlock{CalculateNBlock(N_)}
NBlock{CalculateNBlock(N_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
}
@@ -689,6 +698,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
index_t BK0;
index_t MBlock;
index_t NBlock;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
};
// Argument
@@ -704,8 +716,20 @@ struct GridwiseGemm_xdl_cshuffle_v3
index_t StrideB_,
index_t StrideC_,
index_t k_batch_,
bool is_reduce_ = false)
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_},
bool is_reduce_ = false,
AElementwiseOperation a_element_op = AElementwiseOperation{},
BElementwiseOperation b_element_op = BElementwiseOperation{},
CElementwiseOperation c_element_op = CElementwiseOperation{})
: Problem{M_,
N_,
K_,
StrideA_,
StrideB_,
StrideC_,
k_batch_,
a_element_op,
b_element_op,
c_element_op},
p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_},
@@ -1377,10 +1401,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const AElementwiseOperation a_element_op{};
const BElementwiseOperation b_element_op{};
const CElementwiseOperation c_element_op{};
// divide block work by [M, N]
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
@@ -1440,7 +1460,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
BlockwiseGemmPipe::GlobalBufferNum>(
a_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
problem.a_element_op_,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
@@ -1471,7 +1491,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
BlockwiseGemmPipe::GlobalBufferNum>(
b_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
problem.b_element_op_,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
@@ -1598,42 +1618,67 @@ struct GridwiseGemm_xdl_cshuffle_v3
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
tensor_operation::element_wise::PassThrough pass_through{};
const auto& vpgr_to_lds_element_op = [&] {
if constexpr(DoElementwiseBeforeCShuffle)
{
return problem.c_element_op_;
}
else
{
return pass_through;
}
};
const auto& lds_to_global_element_op = [&] {
if constexpr(!DoElementwiseBeforeCShuffle)
{
return problem.c_element_op_;
}
else
{
return pass_through;
}
};
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
conditional_t<DoElementwiseBeforeCShuffle,
CElementwiseOperation,
tensor_operation::element_wise::PassThrough>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
vpgr_to_lds_element_op()};
// shuffle: blockwise copy C from LDS to global
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation,
ThisThreadBlock, // ThreadGroup
conditional_t<!DoElementwiseBeforeCShuffle,
CElementwiseOperation,
tensor_operation::element_wise::PassThrough>,
CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
@@ -1654,7 +1699,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
make_multi_index(0, 0, 0, 0),
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_m_id, 0, block_n_id, 0),
c_element_op};
lds_to_global_element_op()};
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
@@ -1773,10 +1818,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const AElementwiseOperation a_element_op{};
const BElementwiseOperation b_element_op{};
const CElementwiseOperation c_element_op{};
// divide block work by [M, N]
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
@@ -1836,7 +1877,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
BlockwiseGemmPipe::GlobalBufferNum>(
a_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
problem.a_element_op_,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
@@ -1867,7 +1908,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
BlockwiseGemmPipe::GlobalBufferNum>(
b_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
problem.b_element_op_,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
@@ -2059,7 +2100,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
make_multi_index(0, 0, 0, 0),
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_m_id, 0, block_n_id, 0),
c_element_op};
problem.c_element_op_};
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =