mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK-Tile] move out memory operation from cshuffle epilogue class (#3359)
* initial poc * factor out common parts in operator() * cv4 * rest of the universal gemm pipelines * fix test * remove boilerplate from tile engine * fix example * fix example * format * fix tests build for gemm * remove base pipeline codegen from gemm instance builder * unify v3 logic with the rest of universal gemm pipelines * fix build for multi abd test * fix test gemm multi d * fix build for weight preshuffle * fix grouped gemm test * fix grouped gemm multi d test * fix grouped gemm preshuffle * fix grouped gemm example except for quant * fix gemm preshuffle * fix splitk 2 stage example * fix batched gemm example * fix multid example * fix multiabd example * fix batched gemm test * fixup * fix examples build * fix grouped gemm test build * fix smoke builder * hacky poc * fix tile engine * kill the lambda * maybe fix test build * more fixes * clang-format * save temp * clang-format * mostly fix examples * clang-format * remove dead code * more cleanup * fix fmha bwd build (default epilogue set/add appears to be broken) * fix default epilogue tests but not correctness * clang-format * fix bquant * clang-format * cleanup dead code * rearrange make windows for readability * restore changes to IsSupportedArgument * fix smoke-builder * clang-format * fixup rename class * build fixes * clang-format * fix builder * fixup * remove set from builder tests * fix test * clang-format * re-refactor the kernels * clang-format * fix header license * remove memory operation from conv bwd test * clang-format * clang-format example,include * clang-format test * build fixes * clang-format * solve compilation error * fix the CI * solve compilation error * clang format * solve merge conflict * solve merge conflict * solve the gfx11 error * solve test error * moar build fixes * remove AtomicAddRequiresKBatchGreaterThanOne test since the property is removed from the kernel scope --------- Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
@@ -617,6 +617,117 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeABlockWindow(const OutDataType* a_ptr,
|
||||
const GroupedConvBwdDataKernelArgsSpecialized& kargs,
|
||||
const index_t group_id,
|
||||
const index_t i_m,
|
||||
const index_t i_k)
|
||||
{
|
||||
// Step 1: Create tensor view for A (Out tensor)
|
||||
const auto& a_tensor_view =
|
||||
make_tensor_view<address_space_enum::global>(a_ptr, kargs.a_grid_descs_m_k[group_id]);
|
||||
|
||||
// Step 2: Create padded view
|
||||
const auto& a_pad_view = pad_tensor_view(
|
||||
a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
|
||||
// Step 3: Create tile window
|
||||
auto a_block_window = make_tile_window(
|
||||
a_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
|
||||
{i_m, i_k});
|
||||
|
||||
return a_block_window;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeBBlockWindow(const InDataType* b_ptr,
|
||||
const GroupedConvBwdDataKernelArgsSpecialized& kargs,
|
||||
const index_t group_id,
|
||||
const index_t i_n,
|
||||
const index_t i_k)
|
||||
{
|
||||
// Step 1: Create tensor view for B (Weight tensor)
|
||||
const auto& b_tensor_view =
|
||||
make_tensor_view<address_space_enum::global>(b_ptr, kargs.b_grid_descs_n_k[group_id]);
|
||||
|
||||
// Step 2: Create padded view
|
||||
const auto& b_pad_view = pad_tensor_view(
|
||||
b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
|
||||
// Step 3: Create tile window
|
||||
auto b_block_window = make_tile_window(
|
||||
b_pad_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{i_k, i_n});
|
||||
|
||||
return b_block_window;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeDBlockWindows(const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
const GroupedConvBwdDataKernelArgsSpecialized& kargs,
|
||||
const index_t group_id,
|
||||
const index_t i_m,
|
||||
const index_t i_n)
|
||||
{
|
||||
// Create D tensor block windows
|
||||
const auto ds_block_window = generate_tuple(
|
||||
[&](auto i) {
|
||||
// Step 1: Create tensor view for D
|
||||
const auto& d_tensor_view = make_tensor_view<address_space_enum::global>(
|
||||
static_cast<const OutDataType*>(ds_ptr[i]), kargs.c_grid_descs_m_n[group_id]);
|
||||
|
||||
// Step 2: Create padded view
|
||||
const auto& d_pad_view =
|
||||
pad_tensor_view(d_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
|
||||
// Step 3: Create tile window
|
||||
return make_tile_window(d_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
return ds_block_window;
|
||||
}
|
||||
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeCBlockWindow(WeiDataType* c_ptr,
|
||||
const GroupedConvBwdDataKernelArgsSpecialized& kargs,
|
||||
const index_t group_id,
|
||||
const index_t i_m,
|
||||
const index_t i_n)
|
||||
{
|
||||
// Step 1: Create tensor view for C (Input tensor)
|
||||
const auto& c_tensor_view = make_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
c_ptr, kargs.c_grid_descs_m_n[group_id]);
|
||||
|
||||
// Step 2: Create padded view
|
||||
const auto& c_pad_view = pad_tensor_view(
|
||||
c_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
|
||||
// Step 3: Create tile window
|
||||
auto c_block_window = make_tile_window(
|
||||
c_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
|
||||
return c_block_window;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static bool
|
||||
IsSupportedArgument(const GroupedConvBwdDataKernelArgsSpecialized& kargs)
|
||||
{
|
||||
@@ -895,38 +1006,49 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
const index_t block_idx_k,
|
||||
const index_t group_id)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_ptr, ds_ptr, c_ptr, kargs, group_id);
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
// Create block windows using specialized methods
|
||||
const auto& a_block_window =
|
||||
MakeABlockWindow(a_ptr, kargs, group_id, block_idx_m, block_idx_k);
|
||||
const auto& b_block_window =
|
||||
MakeBBlockWindow(b_ptr, kargs, group_id, block_idx_n, block_idx_k);
|
||||
const auto& d_block_window =
|
||||
MakeDBlockWindows(ds_ptr, kargs, group_id, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitted_k));
|
||||
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
|
||||
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
auto gemm_tile_windows =
|
||||
MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& d_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch);
|
||||
|
||||
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
// Run Epilogue Pipeline with k_batch dispatch
|
||||
if(k_batch == 1)
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::set>(
|
||||
c_ptr, kargs, group_id, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{}
|
||||
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
|
||||
c_ptr, kargs, group_id, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{}
|
||||
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Runs single GEMM problem cooperatively by whole workgroup.
|
||||
*
|
||||
* @note RunGEMM2LDS in with two shared memory buffers using the ping pong buffer mechanism.
|
||||
* @note RunGemm2LDS in with two shared memory buffers using the ping pong buffer mechanism.
|
||||
*
|
||||
* @param a_ptr input A pointer
|
||||
* @param b_ptr input B pointer
|
||||
@@ -951,23 +1073,19 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
const index_t block_idx_k,
|
||||
const index_t group_id)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_ptr, ds_ptr, c_ptr, kargs, group_id);
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
// Create block windows using specialized methods
|
||||
const auto& a_block_window =
|
||||
MakeABlockWindow(a_ptr, kargs, group_id, block_idx_m, block_idx_k);
|
||||
const auto& b_block_window =
|
||||
MakeBBlockWindow(b_ptr, kargs, group_id, block_idx_n, block_idx_k);
|
||||
const auto& d_block_window =
|
||||
MakeDBlockWindows(ds_ptr, kargs, group_id, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitted_k));
|
||||
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
|
||||
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
auto gemm_tile_windows =
|
||||
MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& d_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window,
|
||||
b_block_window,
|
||||
num_loop,
|
||||
@@ -976,11 +1094,27 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
smem_ptr_0,
|
||||
smem_ptr_1);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch);
|
||||
|
||||
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
// Run Epilogue Pipeline with k_batch dispatch
|
||||
if(k_batch == 1)
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::set>(
|
||||
c_ptr, kargs, group_id, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{}
|
||||
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
|
||||
c_ptr, kargs, group_id, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{}
|
||||
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE index_t FindGroupId(const GroupedConvBwdDataKernelArgsSpecialized& kargs,
|
||||
@@ -1066,8 +1200,7 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm2LDS(a_ptr,
|
||||
@@ -1086,8 +1219,7 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm(a_ptr,
|
||||
|
||||
@@ -518,25 +518,6 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
return false;
|
||||
}
|
||||
|
||||
#if defined(__gfx11__)
|
||||
if constexpr(EpiloguePipeline::MemoryOperation != ck_tile::memory_operation_enum::set)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
|
||||
if constexpr(EpiloguePipeline_::MemoryOperation == memory_operation_enum::atomic_add)
|
||||
{
|
||||
if(kargs.k_batch == 1)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Atomic add epilogue only supports k_batch > 1.");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(!std::is_same_v<typename EpiloguePipeline::ODataType, float> &&
|
||||
!std::is_same_v<typename EpiloguePipeline::ODataType, double>)
|
||||
{
|
||||
@@ -704,29 +685,31 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTensorViews(const OutDataType* a_ptr,
|
||||
const InDataType* b_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
WeiDataType* c_ptr,
|
||||
const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
|
||||
MakeCBlockWindow(WeiDataType* c_ptr,
|
||||
const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!");
|
||||
static_assert(!GemmPipeline::BlockGemmShape::PermuteB, "Not implemented!");
|
||||
const auto& a_tensor_view = [&]() {
|
||||
return make_tensor_view<address_space_enum::global>(a_ptr,
|
||||
kargs.a_grid_desc_k_m); // A: out
|
||||
}();
|
||||
const auto& c_tensor_view =
|
||||
make_tensor_view<address_space_enum::global, DstInMemOp>(c_ptr, kargs.c_grid_desc_m_n);
|
||||
|
||||
const auto& b_tensor_view = [&]() {
|
||||
return make_tensor_view<address_space_enum::global>(b_ptr,
|
||||
kargs.b_grid_desc_k_n); // B: in
|
||||
}();
|
||||
const auto& c_pad_view = pad_tensor_view(
|
||||
c_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
|
||||
const auto& c_tensor_view = [&]() {
|
||||
return make_tensor_view<address_space_enum::global, DstInMemOp>(c_ptr,
|
||||
kargs.c_grid_desc_m_n);
|
||||
}();
|
||||
return make_tile_window(
|
||||
c_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{block_idx_m, block_idx_n});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeDBlockWindows(const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
const auto& ds_tensor_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
static_assert(std::is_same_v<std::tuple_element_t<i, DsLayout>, OutLayout>,
|
||||
@@ -741,30 +724,7 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view);
|
||||
}
|
||||
|
||||
template <typename TensorView>
|
||||
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
|
||||
{
|
||||
const auto& a_pad_view = [&]() {
|
||||
const auto& a_tensor_view = views.at(I0);
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
}();
|
||||
|
||||
const auto& b_pad_view = [&]() {
|
||||
const auto& b_tensor_view = views.at(I1);
|
||||
return pad_tensor_view(b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
}();
|
||||
|
||||
const auto& ds_tensor_view = views.at(I2);
|
||||
const auto& ds_pad_view = generate_tuple(
|
||||
const auto& ds_pad_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
return pad_tensor_view(ds_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
@@ -773,67 +733,58 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
const auto& c_pad_view = [&]() {
|
||||
const auto& c_tensor_view = views.at(I3);
|
||||
return pad_tensor_view(c_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
}();
|
||||
|
||||
return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Create views to the data that each workgroup will process.
|
||||
*
|
||||
* @param views padded views of A, B, D and C tensors
|
||||
* @param i_m block m-index
|
||||
* @param i_n block n-index
|
||||
* @param i_k block k-index
|
||||
*
|
||||
* @return tuple of tile windows for A, B, D and C tensors
|
||||
*/
|
||||
template <typename PadView>
|
||||
CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views,
|
||||
const index_t i_m,
|
||||
const index_t i_n,
|
||||
const index_t i_k)
|
||||
{
|
||||
const auto& a_pad_view = views.at(I0);
|
||||
const auto& b_pad_view = views.at(I1);
|
||||
const auto& ds_pad_view = views.at(I2);
|
||||
const auto& c_pad_view = views.at(I3);
|
||||
|
||||
const auto& a_block_window = [&]() {
|
||||
return make_tile_window(a_pad_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
{i_k, i_m});
|
||||
}();
|
||||
|
||||
const auto& b_block_window = [&]() {
|
||||
return make_tile_window(b_pad_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{i_k, i_n});
|
||||
}();
|
||||
|
||||
const auto ds_block_window = generate_tuple(
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
return make_tile_window(ds_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
{block_idx_m, block_idx_n});
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
}
|
||||
|
||||
auto c_block_window = make_tile_window(
|
||||
c_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeBBlockWindow(const InDataType* b_ptr,
|
||||
const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
|
||||
const index_t block_idx_n,
|
||||
const index_t block_idx_k)
|
||||
{
|
||||
static_assert(!GemmPipeline::BlockGemmShape::PermuteB, "Not implemented!");
|
||||
const auto& b_tensor_view =
|
||||
make_tensor_view<address_space_enum::global>(b_ptr, kargs.b_grid_desc_k_n);
|
||||
|
||||
return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window);
|
||||
const auto& b_pad_view =
|
||||
pad_tensor_view(b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{} * kargs.k_batch,
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
|
||||
return make_tile_window(
|
||||
b_pad_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{block_idx_k, block_idx_n});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeABlockWindow(const OutDataType* a_ptr,
|
||||
const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_k)
|
||||
{
|
||||
static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!");
|
||||
const auto& a_tensor_view =
|
||||
make_tensor_view<address_space_enum::global>(a_ptr, kargs.a_grid_desc_k_m);
|
||||
|
||||
const auto& a_pad_view =
|
||||
pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{} * kargs.k_batch,
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
|
||||
return make_tile_window(
|
||||
a_pad_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{}, number<TilePartitioner::MPerBlock>{}),
|
||||
{block_idx_k, block_idx_m});
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -859,28 +810,30 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
const index_t block_idx_n,
|
||||
const index_t block_idx_k)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
|
||||
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows =
|
||||
MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);
|
||||
// Create block windows using helper methods
|
||||
const auto& a_block_window = MakeABlockWindow(a_ptr, kargs, block_idx_m, block_idx_k);
|
||||
const auto& b_block_window = MakeBBlockWindow(b_ptr, kargs, block_idx_n, block_idx_k);
|
||||
const auto& d_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& d_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, smem_ptr_0);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
// Run Epilogue Pipeline with k_batch dispatching
|
||||
if(kargs.k_batch == 1)
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::set>(
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -910,27 +863,33 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
const index_t block_idx_n,
|
||||
const index_t block_idx_k)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows =
|
||||
MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);
|
||||
// Create block windows using helper methods
|
||||
const auto& a_block_window = MakeABlockWindow(a_ptr, kargs, block_idx_m, block_idx_k);
|
||||
const auto& b_block_window = MakeBBlockWindow(b_ptr, kargs, block_idx_n, block_idx_k);
|
||||
const auto& d_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& d_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
// Run Epilogue Pipeline with k_batch dispatching
|
||||
if(kargs.k_batch == 1)
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::set>(
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
#if defined(__gfx11__)
|
||||
return;
|
||||
#endif
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void CallExplicitGemm(GroupedConvBwdWeightKernelArgsSpecialized& kargs) const
|
||||
@@ -960,12 +919,6 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
|
||||
CK_TILE_DEVICE void operator()(GroupedConvBwdWeightKernelArgsSpecialized& kargs) const
|
||||
{
|
||||
#if defined(__gfx11__)
|
||||
if constexpr(EpiloguePipeline::MemoryOperation != ck_tile::memory_operation_enum::set)
|
||||
{
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
if constexpr(GroupedConvTraitsType_::ExplicitGemm)
|
||||
{
|
||||
CallExplicitGemm(kargs);
|
||||
@@ -1001,9 +954,7 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation ==
|
||||
memory_operation_enum::atomic_add &&
|
||||
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<WeiDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm2LDS(a_ptr,
|
||||
@@ -1021,9 +972,7 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation ==
|
||||
memory_operation_enum::atomic_add &&
|
||||
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<WeiDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm(a_ptr,
|
||||
|
||||
@@ -794,34 +794,53 @@ struct GroupedConvolutionForwardKernel
|
||||
return true;
|
||||
}
|
||||
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set,
|
||||
typename ADescType,
|
||||
typename BDescType,
|
||||
typename CDescType>
|
||||
template <typename ADescType>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTensorViews(const InDataType* a_ptr,
|
||||
const WeiDataType* b_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
OutDataType* c_ptr,
|
||||
const ADescType& a_desc,
|
||||
const BDescType& b_desc,
|
||||
const CDescType& c_desc)
|
||||
MakeABlockWindow(const InDataType* a_ptr, const ADescType& a_desc, const index_t block_idx_m)
|
||||
{
|
||||
static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!");
|
||||
static_assert(!GemmPipeline::BlockGemmShape::PermuteB, "Not implemented!");
|
||||
const auto& a_tensor_view = [&]() {
|
||||
return make_tensor_view<address_space_enum::global>(a_ptr, a_desc);
|
||||
}();
|
||||
// Step 1: Create tensor view
|
||||
const auto& a_tensor_view = make_tensor_view<address_space_enum::global>(a_ptr, a_desc);
|
||||
|
||||
const auto& b_tensor_view = [&]() {
|
||||
return make_tensor_view<address_space_enum::global>(b_ptr, b_desc);
|
||||
}();
|
||||
// Step 2: Create padded view
|
||||
const auto& a_pad_view = pad_tensor_view(
|
||||
a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
|
||||
// TODO: enable vector write for C in ColMajor
|
||||
const auto& c_tensor_view = [&]() {
|
||||
return make_tensor_view<address_space_enum::global>(c_ptr, c_desc);
|
||||
}();
|
||||
// Step 3: Create tile window
|
||||
return make_tile_window(
|
||||
a_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
|
||||
{block_idx_m, 0});
|
||||
}
|
||||
|
||||
template <typename BDescType>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeBBlockWindow(const WeiDataType* b_ptr, const BDescType& b_desc, const index_t block_idx_n)
|
||||
{
|
||||
// Step 1: Create tensor view
|
||||
const auto& b_tensor_view = make_tensor_view<address_space_enum::global>(b_ptr, b_desc);
|
||||
|
||||
// Step 2: Create padded view
|
||||
const auto& b_pad_view = pad_tensor_view(
|
||||
b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
|
||||
// Step 3: Create tile window
|
||||
return make_tile_window(
|
||||
b_pad_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
|
||||
{block_idx_n, 0});
|
||||
}
|
||||
|
||||
template <typename CDescType>
|
||||
CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
const CDescType& c_desc,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Step 1: Create tensor views
|
||||
const auto& ds_tensor_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
static_assert(std::is_same_v<std::tuple_element_t<i, DsLayout>, OutLayout>,
|
||||
@@ -836,30 +855,8 @@ struct GroupedConvolutionForwardKernel
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view);
|
||||
}
|
||||
|
||||
template <typename TensorView>
|
||||
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
|
||||
{
|
||||
const auto& a_pad_view = [&]() {
|
||||
const auto& a_tensor_view = views.at(I0);
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
}();
|
||||
|
||||
const auto& b_pad_view = [&]() {
|
||||
const auto& b_tensor_view = views.at(I1);
|
||||
return pad_tensor_view(b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
}();
|
||||
|
||||
const auto& ds_tensor_view = views.at(I2);
|
||||
const auto& ds_pad_view = generate_tuple(
|
||||
// Step 2: Create padded views
|
||||
const auto& ds_pad_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
return pad_tensor_view(ds_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
@@ -868,55 +865,38 @@ struct GroupedConvolutionForwardKernel
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
const auto& c_pad_view = [&]() {
|
||||
const auto& c_tensor_view = views.at(I3);
|
||||
return pad_tensor_view(c_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
}();
|
||||
|
||||
return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view);
|
||||
}
|
||||
|
||||
template <typename PadView>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
|
||||
{
|
||||
const auto& a_pad_view = views.at(I0);
|
||||
const auto& b_pad_view = views.at(I1);
|
||||
const auto& ds_pad_view = views.at(I2);
|
||||
const auto& c_pad_view = views.at(I3);
|
||||
|
||||
const auto& a_block_window = [&]() {
|
||||
return make_tile_window(a_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_m, 0});
|
||||
}();
|
||||
|
||||
const auto& b_block_window = [&]() {
|
||||
return make_tile_window(b_pad_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_n, 0});
|
||||
}();
|
||||
|
||||
const auto ds_block_window = generate_tuple(
|
||||
// Step 3: Create tile windows
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
return make_tile_window(ds_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
{block_idx_m, block_idx_n});
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
}
|
||||
|
||||
auto c_block_window = make_tile_window(
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set, typename CDescType>
|
||||
CK_TILE_DEVICE static auto MakeCBlockWindow(OutDataType* c_ptr,
|
||||
const CDescType& c_desc,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Step 1: Create tensor view
|
||||
const auto& c_tensor_view =
|
||||
make_tensor_view<address_space_enum::global, DstInMemOp>(c_ptr, c_desc);
|
||||
|
||||
// Step 2: Create padded view
|
||||
const auto& c_pad_view = pad_tensor_view(
|
||||
c_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
|
||||
// Step 3: Create tile window
|
||||
return make_tile_window(
|
||||
c_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
|
||||
return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window);
|
||||
{block_idx_m, block_idx_n});
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -931,6 +911,7 @@ struct GroupedConvolutionForwardKernel
|
||||
* @param b_desc Weight tensor B descriptor
|
||||
* @param c_desc Output tensor C descriptor
|
||||
* @param gemm_k The GEMM K dimension
|
||||
* @param k_batch The K batch parameter for split-K
|
||||
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
|
||||
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
|
||||
*
|
||||
@@ -945,34 +926,41 @@ struct GroupedConvolutionForwardKernel
|
||||
const BDescType& b_desc,
|
||||
const CDescType& c_desc,
|
||||
const index_t gemm_k,
|
||||
const index_t k_batch,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n,
|
||||
const CDElementwise& elfunc)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_ptr, ds_ptr, c_ptr, a_desc, b_desc, c_desc);
|
||||
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
// Create block windows using specialized methods
|
||||
const auto& a_block_window = MakeABlockWindow(a_ptr, a_desc, block_idx_m);
|
||||
const auto& b_block_window = MakeBBlockWindow(b_ptr, b_desc, block_idx_n);
|
||||
const auto& ds_block_window = MakeDBlockWindows(ds_ptr, c_desc, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(gemm_k));
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& d_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, smem_ptr_0);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
// Run Epilogue Pipeline with k_batch dispatching
|
||||
if(k_batch == 1)
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::set>(
|
||||
c_ptr, c_desc, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{elfunc}
|
||||
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
EpiloguePipeline{elfunc}
|
||||
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
|
||||
c_ptr, c_desc, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{elfunc}
|
||||
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -990,6 +978,7 @@ struct GroupedConvolutionForwardKernel
|
||||
* @param b_desc Weight tensor B descriptor
|
||||
* @param c_desc Output tensor C descriptor
|
||||
* @param gemm_k The GEMM K dimension
|
||||
* @param k_batch The K batch parameter for split-K
|
||||
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
|
||||
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
|
||||
*
|
||||
@@ -1005,33 +994,41 @@ struct GroupedConvolutionForwardKernel
|
||||
const BDescType& b_desc,
|
||||
const CDescType& c_desc,
|
||||
const index_t gemm_k,
|
||||
const index_t k_batch,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n,
|
||||
const CDElementwise& elfunc)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_ptr, ds_ptr, c_ptr, a_desc, b_desc, c_desc);
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
// Create block windows using specialized methods
|
||||
const auto& a_block_window = MakeABlockWindow(a_ptr, a_desc, block_idx_m);
|
||||
const auto& b_block_window = MakeBBlockWindow(b_ptr, b_desc, block_idx_n);
|
||||
const auto& ds_block_window = MakeDBlockWindows(ds_ptr, c_desc, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(gemm_k));
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& d_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
// Run Epilogue Pipeline with k_batch dispatching
|
||||
if(k_batch == 1)
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::set>(
|
||||
c_ptr, c_desc, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{elfunc}
|
||||
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
EpiloguePipeline{elfunc}
|
||||
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
|
||||
c_ptr, c_desc, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{elfunc}
|
||||
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void CallExplicitGemm(GroupedConvFwdKernelArgsSpecialized& kargs) const
|
||||
@@ -1185,9 +1182,7 @@ struct GroupedConvolutionForwardKernel
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation ==
|
||||
memory_operation_enum::atomic_add &&
|
||||
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm2LDS(a_ptr,
|
||||
@@ -1200,6 +1195,7 @@ struct GroupedConvolutionForwardKernel
|
||||
b_desc,
|
||||
c_desc,
|
||||
kargs.GemmK,
|
||||
kargs.k_batch,
|
||||
i_m,
|
||||
i_n,
|
||||
kargs.elfunc);
|
||||
@@ -1207,9 +1203,7 @@ struct GroupedConvolutionForwardKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation ==
|
||||
memory_operation_enum::atomic_add &&
|
||||
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm(a_ptr,
|
||||
@@ -1221,6 +1215,7 @@ struct GroupedConvolutionForwardKernel
|
||||
b_desc,
|
||||
c_desc,
|
||||
kargs.GemmK,
|
||||
kargs.k_batch,
|
||||
i_m,
|
||||
i_n,
|
||||
kargs.elfunc);
|
||||
|
||||
Reference in New Issue
Block a user