mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
Add support for groups in Img2Col/Col2Img (#1007)
* Add support for groups in Img2Col/Col2Img
* Fix interface test
* Fix interface test G to N
* Improve performance
* Change gemm layout to 3d
* Fixes
This commit is contained in:
@@ -21,6 +21,7 @@ template <typename InputGridDesc,
|
||||
typename OutputGridDesc,
|
||||
typename OutputDataType,
|
||||
typename Block2ETileMap,
|
||||
typename ComputePtrOffsetOfStridedBatch,
|
||||
typename GridwiseTensorRearrangeKernel>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
@@ -30,13 +31,20 @@ __global__ void
|
||||
const InputDataType* __restrict__ p_in_global,
|
||||
const OutputGridDesc out_grid_desc,
|
||||
OutputDataType* __restrict__ p_out_global,
|
||||
const Block2ETileMap block_2_tile_map)
|
||||
const index_t batch_count,
|
||||
const Block2ETileMap block_2_tile_map,
|
||||
const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || \
|
||||
defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
GridwiseTensorRearrangeKernel::Run(
|
||||
in_grid_desc, p_in_global, out_grid_desc, p_out_global, block_2_tile_map);
|
||||
GridwiseTensorRearrangeKernel::Run(in_grid_desc,
|
||||
p_in_global,
|
||||
out_grid_desc,
|
||||
p_out_global,
|
||||
batch_count,
|
||||
block_2_tile_map,
|
||||
compute_ptr_offset_of_batch);
|
||||
#else
|
||||
ignore = in_grid_desc;
|
||||
ignore = p_in_global;
|
||||
@@ -56,7 +64,8 @@ template <typename InputGridDesc,
|
||||
typename ThreadClusterLengths,
|
||||
index_t ScalarPerVector,
|
||||
InMemoryDataOperationEnum DstInMemOp,
|
||||
typename Block2ETileMap>
|
||||
typename Block2ETileMap,
|
||||
typename ComputePtrOffsetOfStridedBatch>
|
||||
struct GridwiseTensorRearrange
|
||||
{
|
||||
|
||||
@@ -69,7 +78,9 @@ struct GridwiseTensorRearrange
|
||||
const InputDataType* __restrict__ p_in_global,
|
||||
const OutputGridDesc& out_grid_desc,
|
||||
OutputDataType* __restrict__ p_out_global,
|
||||
const Block2ETileMap& block_2_tile_map)
|
||||
const index_t batch_count,
|
||||
const Block2ETileMap& block_2_tile_map,
|
||||
const ComputePtrOffsetOfStridedBatch& compute_ptr_offset_of_batch)
|
||||
{
|
||||
const auto block_work_idx =
|
||||
block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
@@ -80,12 +91,6 @@ struct GridwiseTensorRearrange
|
||||
const index_t k_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * KPerBlock);
|
||||
|
||||
// Global Memory
|
||||
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_in_global, in_grid_desc.GetElementSpaceSize());
|
||||
auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_out_global, out_grid_desc.GetElementSpaceSize());
|
||||
|
||||
auto copy_global_to_global =
|
||||
ThreadGroupTensorSliceTransfer_v7<ThisThreadBlock,
|
||||
Tuple<InputDataType>,
|
||||
@@ -108,6 +113,22 @@ struct GridwiseTensorRearrange
|
||||
make_tuple(make_multi_index(m_block_data_idx_on_grid, k_block_data_idx_on_grid)),
|
||||
tensor_operation::element_wise::PassThrough{}};
|
||||
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx =
|
||||
__builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
// Global Memory
|
||||
const index_t a_batch_offset =
|
||||
__builtin_amdgcn_readfirstlane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
|
||||
const index_t c_batch_offset =
|
||||
__builtin_amdgcn_readfirstlane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx));
|
||||
|
||||
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_in_global + a_batch_offset, in_grid_desc.GetElementSpaceSize());
|
||||
auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_out_global + c_batch_offset, out_grid_desc.GetElementSpaceSize());
|
||||
|
||||
copy_global_to_global.Run(
|
||||
tie(in_grid_desc), tie(in_global_buf), tie(out_grid_desc), tie(out_global_buf));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user