mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
N dimension parallelism code drop
This commit is contained in:
@@ -180,7 +180,7 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
else
|
||||
{
|
||||
std::cout << "Ping pong....ON " << std::endl;
|
||||
grids = Kernel::PingPongGridSize(args.M, args.N, args.K, args.k_batch);
|
||||
grids = Kernel::PingPongGridSizeNParallel(args.M, args.N, args.K, args.k_batch);
|
||||
std::cout << "Arguments: { " << args.M << ", " << args.N << ", " << args.K << ", " << args.k_batch << " }" << std::endl;
|
||||
std::cout << "Grid size : {" << grids.x << ", " << grids.y << ", " << grids.z
|
||||
<< "}" << std::endl;
|
||||
|
||||
@@ -361,14 +361,17 @@ struct CShuffleEpilogue
|
||||
buffer_store_fence();
|
||||
if constexpr(iAccess != num_access - 1)
|
||||
{
|
||||
constexpr auto step = SFC::get_forward_step(iAccess);
|
||||
if (execute_epilogue)
|
||||
{
|
||||
constexpr auto step = SFC::get_forward_step(iAccess);
|
||||
|
||||
move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})});
|
||||
move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})});
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto idx) {
|
||||
move_tile_window(d_dram_windows[idx],
|
||||
{step.at(number<0>{}), step.at(number<1>{})});
|
||||
});
|
||||
static_for<0, NumDTensor, 1>{}([&](auto idx) {
|
||||
move_tile_window(d_dram_windows[idx],
|
||||
{step.at(number<0>{}), step.at(number<1>{})});
|
||||
});
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -135,6 +135,11 @@ struct GemmKernel
|
||||
return dim3(TilePartitioner::PingPongGridSize(N, K), 1, KBatch);
|
||||
}
|
||||
|
||||
/// Host-side launch grid for the "N parallel" ping-pong GEMM variant.
///
/// x is the workgroup count reported by TilePartitioner::PingPongGridSizeNParallel(M, K),
/// y is fixed at 1 and z is KBatch. The second argument (N at the call site) is unnamed
/// and not used by this function.
CK_TILE_HOST static auto
PingPongGridSizeNParallel(index_t M, index_t /*N*/, index_t K, index_t KBatch) -> dim3
{
    const auto workgroups_x = TilePartitioner::PingPongGridSizeNParallel(M, K);
    return dim3(workgroups_x, 1, KBatch);
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() -> dim3
|
||||
{
|
||||
return UniversalGemmKernel::BlockSize();
|
||||
|
||||
@@ -265,6 +265,18 @@ struct GemmSpatiallyLocalTilePartitioner
|
||||
return GridDimX * GridDimY;
|
||||
}
|
||||
|
||||
/**
 * @brief Workgroup count for the "N parallel" ping-pong GEMM launch.
 *
 * @param M GEMM M extent, split into tiles of MPerBlock.
 * @param K GEMM K extent, split into tiles of KPerBlock.
 * @return integer_divide_ceil(M, MPerBlock) * integer_divide_ceil(K, KPerBlock).
 *
 * @note The diagnostic prints use std::cout, which exists only on the host. They are
 *       compiled out of the device pass so the CK_TILE_HOST_DEVICE annotation stays
 *       truthful; host-side output is unchanged.
 */
CK_TILE_HOST_DEVICE static auto
PingPongGridSizeNParallel(index_t M, index_t K) noexcept(noexcept(MPerBlock != 0 && KPerBlock != 0)) -> index_t
{
    // Despite the X/Y names these are the tile counts along M and K respectively.
    const index_t GridDimX = integer_divide_ceil(M, MPerBlock);
    const index_t GridDimY = integer_divide_ceil(K, KPerBlock);

#if !defined(__HIP_DEVICE_COMPILE__)
    // Host-only debug trace of the launch geometry.
    std::cout << "PingPong Grid size, N_DIM_PARALLELISM M GRID SIZE : {" << GridDimX << ", " << GridDimY << "}" << std::endl;
    std::cout << "Arguments: { " << M << ", " << K << " }" << std::endl;
    std::cout << "Block size : {" << MPerBlock << ", " << KPerBlock << "}" << std::endl;
#endif

    return GridDimX * GridDimY;
}
|
||||
|
||||
/**
|
||||
* @brief Calculate number of loop iterations over GEMM's K dimension.
|
||||
*
|
||||
|
||||
@@ -287,7 +287,13 @@ struct UniversalGemmKernel
|
||||
CK_TILE_HOST static auto PingPongGridSize(index_t, index_t N, index_t K, index_t KBatch) -> dim3
|
||||
{
|
||||
return dim3(TilePartitioner::PingPongGridSize(N, K), 1, KBatch);
|
||||
}
|
||||
}
|
||||
|
||||
/// Host-side launch grid for the "N parallel" ping-pong GEMM variant.
///
/// The x extent comes from TilePartitioner::PingPongGridSizeNParallel(M, K); y is 1 and
/// z is KBatch. The second argument is unnamed and ignored here.
CK_TILE_HOST static auto
PingPongGridSizeNParallel(index_t M, index_t /*N*/, index_t K, index_t KBatch) -> dim3
{
    const auto grid_x = TilePartitioner::PingPongGridSizeNParallel(M, K);
    return dim3(grid_x, 1, KBatch);
}
|
||||
|
||||
CK_TILE_HOST static auto BlockSize()
|
||||
{
|
||||
@@ -855,41 +861,45 @@ struct UniversalGemmKernel
|
||||
{
|
||||
const auto& a_pad_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
const auto& a_tensor_view = views.at(I0);
|
||||
using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
|
||||
if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(a_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(a_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{}, number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadM>{});
|
||||
}
|
||||
},
|
||||
number<NumATensor>{});
|
||||
|
||||
const auto& a_tensor_view = views.at(I0);
|
||||
using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
|
||||
if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(a_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(a_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadM>{});
|
||||
}
|
||||
},
|
||||
number<NumATensor>{});
|
||||
|
||||
const auto& b_pad_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
const auto& b_tensor_view = views.at(I1);
|
||||
using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
|
||||
if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return pad_tensor_view(b_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(b_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadN>{});
|
||||
}
|
||||
},
|
||||
number<NumBTensor>{});
|
||||
const auto& b_tensor_view = views.at(I1);
|
||||
using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
|
||||
if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return pad_tensor_view(b_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(b_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadN>{});
|
||||
}
|
||||
},
|
||||
number<NumBTensor>{});
|
||||
|
||||
const auto& d_pad_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
@@ -910,29 +920,29 @@ struct UniversalGemmKernel
|
||||
sequence<false, GemmPipeline::kPadM>{});
|
||||
}
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
number<NumDTensor>{});
|
||||
|
||||
// TODO vector write in for C in ColMajor
|
||||
const auto& e_pad_view = [&]() {
|
||||
const auto& e_tensor_view = views.at(I3);
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(
|
||||
e_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadN>{});
|
||||
return pad_tensor_view(e_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadN>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(
|
||||
e_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<GemmPipeline::kPadM, false>{});
|
||||
return pad_tensor_view(e_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<GemmPipeline::kPadM, false>{});
|
||||
}
|
||||
}();
|
||||
}();
|
||||
|
||||
return make_tuple(a_pad_view, b_pad_view, d_pad_view, e_pad_view);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename PadView>
|
||||
CK_TILE_DEVICE static auto
|
||||
@@ -1024,68 +1034,77 @@ struct UniversalGemmKernel
|
||||
}
|
||||
|
||||
template <typename PadView>
|
||||
CK_TILE_DEVICE static auto MakePingPongGemmTileWindows
|
||||
(const PadView& views, const index_t i_n, const index_t i_k, [[maybe_unused]] const index_t M, [[maybe_unused]] const index_t N, [[maybe_unused]] const index_t K)
|
||||
CK_TILE_DEVICE static auto
|
||||
MakePingPongGemmTileWindowsMParallel(const PadView& views,
|
||||
const index_t i_n,
|
||||
const index_t i_k,
|
||||
[[maybe_unused]] const index_t M,
|
||||
[[maybe_unused]] const index_t N,
|
||||
[[maybe_unused]] const index_t K)
|
||||
{
|
||||
const auto& as_pad_view = views.at(I0);
|
||||
const auto& bs_pad_view = views.at(I1);
|
||||
const auto& ds_pad_view = views.at(I2);
|
||||
const auto& e_pad_view = views.at(I3);
|
||||
const auto& e_pad_view = views.at(I3);
|
||||
|
||||
const auto& as_block_window = generate_tuple(
|
||||
[&](auto i) {
|
||||
using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
|
||||
if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(
|
||||
as_pad_view[i], make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}), {0, i_k});
|
||||
return make_tile_window(as_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{0, i_k});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(
|
||||
as_pad_view[i], make_tuple(number<TilePartitioner::KPerBlock>{}, number<TilePartitioner::MPerBlock>{}), {i_k, 0});
|
||||
return make_tile_window(as_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
{i_k, 0});
|
||||
}
|
||||
},
|
||||
number<NumATensor>{});
|
||||
},
|
||||
number<NumATensor>{});
|
||||
|
||||
const auto& bs_block_window = generate_tuple(
|
||||
[&](auto i) {
|
||||
using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
|
||||
if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return make_tile_window(
|
||||
bs_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
|
||||
{i_n, i_k});
|
||||
return make_tile_window(bs_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_n, i_k});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(
|
||||
bs_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{i_k, i_n});
|
||||
return make_tile_window(bs_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{i_k, i_n});
|
||||
}
|
||||
},
|
||||
number<NumBTensor>{});
|
||||
},
|
||||
number<NumBTensor>{});
|
||||
|
||||
const auto& ds_block_window = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(
|
||||
ds_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{i_n, i_k});
|
||||
return make_tile_window(ds_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{i_n, i_k});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(
|
||||
ds_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::MPerBlock>{}),
|
||||
{i_k, i_n});
|
||||
return make_tile_window(ds_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
{i_k, i_n});
|
||||
}
|
||||
},
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
auto e_block_window = make_tile_window(
|
||||
@@ -1093,8 +1112,90 @@ struct UniversalGemmKernel
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{0, i_n});
|
||||
|
||||
return make_tuple(as_block_window, bs_block_window, ds_block_window, e_block_window);
|
||||
}
|
||||
return make_tuple(as_block_window, bs_block_window, ds_block_window, e_block_window);
|
||||
}
|
||||
|
||||
template <typename PadView>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakePingPongGemmTileWindowsNParallel(const PadView& views,
|
||||
const index_t i_m,
|
||||
const index_t i_k,
|
||||
[[maybe_unused]] const index_t M,
|
||||
[[maybe_unused]] const index_t N,
|
||||
[[maybe_unused]] const index_t K)
|
||||
{
|
||||
const auto& as_pad_view = views.at(I0);
|
||||
const auto& bs_pad_view = views.at(I1);
|
||||
const auto& ds_pad_view = views.at(I2);
|
||||
const auto& e_pad_view = views.at(I3);
|
||||
|
||||
const auto& as_block_window = generate_tuple(
|
||||
[&](auto i) {
|
||||
using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
|
||||
if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(as_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_m, i_k});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(as_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
{i_k, i_m});
|
||||
}
|
||||
},
|
||||
number<NumATensor>{});
|
||||
|
||||
const auto& bs_block_window = generate_tuple(
|
||||
[&](auto i) {
|
||||
using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
|
||||
if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return make_tile_window(bs_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{0, i_k});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(bs_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{i_k, 0});
|
||||
}
|
||||
},
|
||||
number<NumBTensor>{});
|
||||
|
||||
const auto& ds_block_window = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(ds_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_k});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(ds_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
{i_k, i_m});
|
||||
}
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
auto e_block_window = make_tile_window(
|
||||
e_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, 0});
|
||||
|
||||
return make_tuple(as_block_window, bs_block_window, ds_block_window, e_block_window);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Runs single GEMM problem cooperatively by whole workgroup.
|
||||
@@ -1149,43 +1250,42 @@ struct UniversalGemmKernel
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static void PingPongGemm(const std::array<const ADataType*, NumATensor>& a_ptr,
|
||||
const std::array<const BDataType*, NumBTensor>& b_ptr,
|
||||
const std::array<const void*, NumDTensor>& d_ptr,
|
||||
EDataType* e_ptr,
|
||||
void* smem_ptr_0,
|
||||
const KernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
[[maybe_unused]] const index_t block_idx_n,
|
||||
[[maybe_unused]] const index_t block_idx_k)
|
||||
// PingPongGemmNDim(as_ptr, bs_ptr, kargs.ds_ptr, es_ptr, smem_ptr_0, smem_ptr_1,
|
||||
// smem_ptr_2, kargs, i_n, i_k);
|
||||
CK_TILE_DEVICE static void
|
||||
PingPongGemmNDim(const std::array<const ADataType*, NumATensor>& a_ptr,
|
||||
const std::array<const BDataType*, NumBTensor>& b_ptr,
|
||||
const std::array<const void*, NumDTensor>& d_ptr,
|
||||
EDataType* e_ptr,
|
||||
void* smem_ptr_0,
|
||||
void* smem_ptr_1,
|
||||
void* smem_ptr_2,
|
||||
const KernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
[[maybe_unused]] const index_t block_idx_n,
|
||||
[[maybe_unused]] const index_t block_idx_k)
|
||||
{
|
||||
const auto blockId = __builtin_amdgcn_readfirstlane(blockIdx.x);
|
||||
const auto kBlocks = __builtin_amdgcn_readfirstlane(integer_divide_ceil(
|
||||
kargs.K, TilePartitioner::KPerBlock));
|
||||
auto idx_n = __builtin_amdgcn_readfirstlane(blockId / kBlocks);
|
||||
auto idx_k = __builtin_amdgcn_readfirstlane(blockId % kBlocks);
|
||||
auto n_offset = __builtin_amdgcn_readfirstlane(idx_n * TilePartitioner::NPerBlock);
|
||||
const auto blockId = __builtin_amdgcn_readfirstlane(blockIdx.x);
|
||||
|
||||
const auto kBlocks = __builtin_amdgcn_readfirstlane(
|
||||
integer_divide_ceil(kargs.K, TilePartitioner::KPerBlock));
|
||||
|
||||
auto idx_m = __builtin_amdgcn_readfirstlane(blockId / kBlocks);
|
||||
auto idx_k = __builtin_amdgcn_readfirstlane(blockId % kBlocks);
|
||||
auto m_offset = __builtin_amdgcn_readfirstlane(idx_m * TilePartitioner::MPerBlock);
|
||||
auto k_offset = __builtin_amdgcn_readfirstlane(idx_k * TilePartitioner::KPerBlock);
|
||||
|
||||
//auto idx_k = __builtin_amdgcn_readfirstlane(blockId / kargs.N);
|
||||
//auto idx_n = __builtin_amdgcn_readfirstlane(blockId % TilePartitioner::NPerBlock);
|
||||
|
||||
//auto n_offset = __builtin_amdgcn_readfirstlane(idx_n * TilePartitioner::NPerBlock);
|
||||
//auto k_offset = __builtin_amdgcn_readfirstlane(idx_k * TilePartitioner::KPerBlock);
|
||||
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_ptr, d_ptr, e_ptr, kargs, splitk_batch_offset);
|
||||
|
||||
const auto& gemm_pad_views =
|
||||
MakePingPongGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows =
|
||||
MakePingPongGemmTileWindows(gemm_pad_views, n_offset, k_offset, kargs.M, kargs.N, kargs.K);
|
||||
const auto& gemm_pad_views = MakePingPongGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakePingPongGemmTileWindowsNParallel(
|
||||
gemm_pad_views, m_offset, k_offset, kargs.M, kargs.N, kargs.K);
|
||||
|
||||
const index_t num_loop = __builtin_amdgcn_readfirstlane(integer_divide_ceil(
|
||||
//kargs.M, TilePartitioner::MPerBlock * GemmPipeline::BlockGemmShape::NumWarps));
|
||||
kargs.M, TilePartitioner::MPerBlock));
|
||||
const index_t num_loop = __builtin_amdgcn_readfirstlane(
|
||||
integer_divide_ceil(kargs.N, TilePartitioner::NPerBlock));
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
@@ -1193,21 +1293,89 @@ struct UniversalGemmKernel
|
||||
const auto& d_block_window = gemm_tile_windows.at(I2);
|
||||
auto& e_block_window = gemm_tile_windows.at(I3);
|
||||
|
||||
|
||||
const auto EpilogueFunc = [&](auto &out_window, auto& tile, auto &ds_window, auto execute_epilogue) {
|
||||
EpiloguePipeline{}.template operator()<decltype(out_window), decltype(tile), decltype(ds_window)>(
|
||||
out_window, tile, ds_window, smem_ptr_0, execute_epilogue);
|
||||
};
|
||||
|
||||
const auto EpilogueFunc =
|
||||
[&](auto& out_window, auto& tile, auto& ds_window, auto execute_epilogue) {
|
||||
EpiloguePipeline{}
|
||||
.template operator()<decltype(out_window), decltype(tile), decltype(ds_window)>(
|
||||
out_window, tile, ds_window, smem_ptr_2, execute_epilogue);
|
||||
};
|
||||
|
||||
GemmPipeline{}.template operator()(a_block_window[I0],
|
||||
b_block_window[I0],
|
||||
d_block_window,
|
||||
e_block_window,
|
||||
num_loop,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1,
|
||||
EpilogueFunc);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static void
|
||||
PingPongGemmMDim(const std::array<const ADataType*, NumATensor>& a_ptr,
|
||||
const std::array<const BDataType*, NumBTensor>& b_ptr,
|
||||
const std::array<const void*, NumDTensor>& d_ptr,
|
||||
EDataType* e_ptr,
|
||||
void* smem_ptr_0,
|
||||
const KernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
[[maybe_unused]] const index_t block_idx_n,
|
||||
[[maybe_unused]] const index_t block_idx_k)
|
||||
{
|
||||
const auto blockId = __builtin_amdgcn_readfirstlane(blockIdx.x);
|
||||
const auto kBlocks = __builtin_amdgcn_readfirstlane(
|
||||
integer_divide_ceil(kargs.K, TilePartitioner::KPerBlock));
|
||||
auto idx_n = __builtin_amdgcn_readfirstlane(blockId / kBlocks);
|
||||
auto idx_k = __builtin_amdgcn_readfirstlane(blockId % kBlocks);
|
||||
auto n_offset = __builtin_amdgcn_readfirstlane(idx_n * TilePartitioner::NPerBlock);
|
||||
auto k_offset = __builtin_amdgcn_readfirstlane(idx_k * TilePartitioner::KPerBlock);
|
||||
|
||||
// auto idx_k = __builtin_amdgcn_readfirstlane(blockId / kargs.N);
|
||||
// auto idx_n = __builtin_amdgcn_readfirstlane(blockId % TilePartitioner::NPerBlock);
|
||||
|
||||
// auto n_offset = __builtin_amdgcn_readfirstlane(idx_n * TilePartitioner::NPerBlock);
|
||||
// auto k_offset = __builtin_amdgcn_readfirstlane(idx_k * TilePartitioner::KPerBlock);
|
||||
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_ptr, d_ptr, e_ptr, kargs, splitk_batch_offset);
|
||||
|
||||
const auto& gemm_pad_views = MakePingPongGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakePingPongGemmTileWindowsMParallel(
|
||||
gemm_pad_views, n_offset, k_offset, kargs.M, kargs.N, kargs.K);
|
||||
|
||||
const index_t num_loop = __builtin_amdgcn_readfirstlane(integer_divide_ceil(
|
||||
// kargs.M, TilePartitioner::MPerBlock * GemmPipeline::BlockGemmShape::NumWarps));
|
||||
kargs.M,
|
||||
TilePartitioner::MPerBlock));
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& d_block_window = gemm_tile_windows.at(I2);
|
||||
auto& e_block_window = gemm_tile_windows.at(I3);
|
||||
|
||||
const auto EpilogueFunc =
|
||||
[&](auto& out_window, auto& tile, auto& ds_window, auto execute_epilogue) {
|
||||
EpiloguePipeline{}
|
||||
.template operator()<decltype(out_window), decltype(tile), decltype(ds_window)>(
|
||||
out_window, tile, ds_window, smem_ptr_0, execute_epilogue);
|
||||
};
|
||||
|
||||
/*
|
||||
const auto EpilogueFunc = [&](auto &out_window, auto& tile) {
|
||||
EpiloguePipeline{}.template operator()<decltype(out_window), decltype(tile)>(
|
||||
out_window, tile);
|
||||
};
|
||||
};
|
||||
*/
|
||||
GemmPipeline{}.template operator()(
|
||||
a_block_window[I0], b_block_window[I0], d_block_window, e_block_window, num_loop, smem_ptr_0, EpilogueFunc);
|
||||
}
|
||||
GemmPipeline{}.template operator()(a_block_window[I0],
|
||||
b_block_window[I0],
|
||||
d_block_window,
|
||||
e_block_window,
|
||||
num_loop,
|
||||
smem_ptr_0,
|
||||
EpilogueFunc);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Runs single GEMM problem cooperatively by whole workgroup.
|
||||
@@ -1296,9 +1464,23 @@ struct UniversalGemmKernel
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_0[GetSmemSize()];
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
__shared__ char smem_ptr_2[GetSmemSize()];
|
||||
PingPongGemmNDim(as_ptr,
|
||||
bs_ptr,
|
||||
kargs.ds_ptr,
|
||||
es_ptr,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1,
|
||||
smem_ptr_2,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_n,
|
||||
i_k);
|
||||
|
||||
PingPongGemm(
|
||||
as_ptr, bs_ptr, kargs.ds_ptr, es_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_n, i_k);
|
||||
// PingPongGemmMDim(
|
||||
// as_ptr, bs_ptr, kargs.ds_ptr, es_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_n,
|
||||
// i_k);
|
||||
}
|
||||
|
||||
// Persistent kernel entry point
|
||||
|
||||
@@ -124,6 +124,238 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
{
|
||||
using Base = PipelineImplBase;
|
||||
|
||||
template <bool HasHotLoop,
|
||||
TailNumber TailNum,
|
||||
index_t PingPongDim,
|
||||
typename ADramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename BElementFunction,
|
||||
typename DDramBlockWindowTmp,
|
||||
typename CDramBlockWindowTmp,
|
||||
typename EpilogueFunction>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
[[maybe_unused]] const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
[[maybe_unused]] const BElementFunction& b_element_func,
|
||||
[[maybe_unused]] const DDramBlockWindowTmp& d_dram_block_window_tmp,
|
||||
[[maybe_unused]] CDramBlockWindowTmp& c_dram_block_window_tmp,
|
||||
[[maybe_unused]] index_t num_loop,
|
||||
void* __restrict__ p_smem_0,
|
||||
[[maybe_unused]] void* __restrict__ p_smem_1,
|
||||
[[maybe_unused]] const EpilogueFunction& epilogue_func
|
||||
) const
|
||||
{
|
||||
|
||||
[[maybe_unused]] constexpr bool is_a_col_major =
|
||||
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
[[maybe_unused]] constexpr bool is_b_row_major =
|
||||
std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
[[maybe_unused]] constexpr bool is_c_col_major =
|
||||
std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>;
|
||||
|
||||
static_assert(NumWaveGroups == 2);
|
||||
|
||||
index_t warp_id = get_warp_id();
|
||||
[[maybe_unused]] index_t operation_id = __builtin_amdgcn_readfirstlane((get_warp_id() + 1) % NumWaveGroups);
|
||||
|
||||
[[maybe_unused]] auto b_offset = (warp_id == 0) ? make_array(0, 0) : make_array(NPerBlock, 0); // column major
|
||||
[[maybe_unused]] auto c_offset = (warp_id == 0) ? make_array(0, 0) : make_array(0, NPerBlock); // row major
|
||||
|
||||
[[maybe_unused]] auto tensor_views =
|
||||
Base::GetABLdsTensorViews(static_cast<void*>(static_cast<char*>(p_smem_0)));
|
||||
[[maybe_unused]] auto& a_lds_block = tensor_views.get(number<0>{});
|
||||
[[maybe_unused]] auto& b_lds_block = tensor_views.get(number<1>{});
|
||||
|
||||
[[maybe_unused]] constexpr auto a_lds_laod_tile_distr =
|
||||
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
|
||||
[[maybe_unused]] constexpr auto b_lds_load_tile_distr =
|
||||
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
|
||||
|
||||
[[maybe_unused]] auto a_windows =
|
||||
Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_laod_tile_distr);
|
||||
[[maybe_unused]] auto& a_copy_dram_window = a_windows.get(number<0>{});
|
||||
[[maybe_unused]] auto& a_copy_lds_window = a_windows.get(number<1>{});
|
||||
[[maybe_unused]] auto& a_lds_window = a_windows.get(number<2>{});
|
||||
|
||||
[[maybe_unused]] auto b_windows =
|
||||
Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr, b_offset);
|
||||
[[maybe_unused]] auto& b_copy_dram_window = b_windows.get(number<0>{});
|
||||
[[maybe_unused]] auto& b_copy_lds_window = b_windows.get(number<1>{});
|
||||
[[maybe_unused]] auto& b_lds_window = b_windows.get(number<2>{});
|
||||
|
||||
[[maybe_unused]] auto epilogue_dram_window =
|
||||
make_tile_window(c_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(MPerBlock, NPerBlock),
|
||||
c_dram_block_window_tmp.get_window_origin() + c_offset);
|
||||
|
||||
// DRAM window steps.
|
||||
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
|
||||
[[maybe_unused]] constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, 0); // A is constant.
|
||||
|
||||
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
|
||||
[[maybe_unused]] constexpr BDramTileWindowStep b_dram_tile_window_step =
|
||||
is_b_row_major ? make_array(0, NPerBlock * NumWarps) // (k, N)
|
||||
: make_array(NPerBlock * NumWarps, 0); // (N, K)
|
||||
|
||||
using CDramBlockWindowStep = typename CDramBlockWindowTmp::BottomTensorIndex;
|
||||
[[maybe_unused]] constexpr CDramBlockWindowStep c_dram_tile_window_step =
|
||||
is_c_col_major ? make_array(NPerBlock * NumWarps, 0) : make_array(0, NPerBlock * NumWarps);
|
||||
|
||||
[[maybe_unused]] constexpr auto AGemmTileDistr = decltype(make_static_tile_distribution(
|
||||
BlockGemm::MakeABlockDistributionEncode())){};
|
||||
[[maybe_unused]] constexpr auto BGemmTileDistr = decltype(make_static_tile_distribution(
|
||||
BlockGemm::MakeBBlockDistributionEncode())){};
|
||||
|
||||
using AGemmTile = decltype(make_static_distributed_tensor<ADataType>(AGemmTileDistr));
|
||||
using BGemmTile = decltype(make_static_distributed_tensor<BDataType>(BGemmTileDistr));
|
||||
|
||||
[[maybe_unused]] AGemmTile a_tile;
|
||||
[[maybe_unused]] BGemmTile b_tile_0, b_tile_1;
|
||||
|
||||
// Register tiles for A and B.
|
||||
using ABlockTileDistr =
|
||||
decltype(a_copy_dram_window.get_tile_distribution());
|
||||
using BBlockTileDistr =
|
||||
decltype(b_copy_dram_window.get_tile_distribution());
|
||||
|
||||
using ABlockTile =
|
||||
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
|
||||
using BBlockTile =
|
||||
decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
|
||||
|
||||
[[maybe_unused]] ABlockTile a_dram_tile;
|
||||
[[maybe_unused]] BBlockTile b_dram_tile;
|
||||
|
||||
// Block GEMM
|
||||
auto block_gemm = BlockGemm();
|
||||
auto c_block_tile_0 = block_gemm.MakeCBlockTile();
|
||||
//auto c_block_tile_1 = block_gemm.MakeCBlockTile();
|
||||
|
||||
[[maybe_unused]] auto ReadA = [&](){
|
||||
|
||||
Base::GlobalPrefetch(a_dram_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_dram_tile, a_element_func);
|
||||
Base::LocalPrefetch(a_tile, a_lds_window);
|
||||
//tile_elementwise_inout([](auto& c) { c = 5; }, a_tile);
|
||||
};
|
||||
|
||||
[[maybe_unused]] auto ReadB = [&](auto idx)
|
||||
{
|
||||
Base::GlobalPrefetch(b_dram_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_dram_tile, b_element_func);
|
||||
if (idx == 0)
|
||||
{
|
||||
Base::LocalPrefetch(b_tile_0, b_lds_window);
|
||||
//tile_elementwise_inout([](auto& c) { c = 1; }, b_tile_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefetch(b_tile_1, b_lds_window);
|
||||
//tile_elementwise_inout([](auto& c) { c = 2; }, b_tile_1);
|
||||
}
|
||||
};
|
||||
|
||||
[[maybe_unused]] auto ComputeStep = [&](auto idx){
|
||||
if (idx == 0)
|
||||
{
|
||||
c_block_tile_0 = block_gemm(a_tile, b_tile_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
c_block_tile_0 = block_gemm(a_tile, b_tile_1);
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
ReadA();
|
||||
if (warp_id == 0)
|
||||
{
|
||||
ReadB(warp_id);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (warp_id == 0)
|
||||
{
|
||||
ComputeStep(warp_id);
|
||||
}
|
||||
else
|
||||
{
|
||||
ReadB(warp_id);
|
||||
}
|
||||
__syncthreads();
|
||||
epilogue_func(epilogue_dram_window, c_block_tile_0, d_dram_block_window_tmp, (warp_id == 0));
|
||||
__syncthreads();
|
||||
|
||||
if (warp_id == 1)
|
||||
{
|
||||
ComputeStep(warp_id);
|
||||
}
|
||||
__syncthreads();
|
||||
epilogue_func(epilogue_dram_window, c_block_tile_0, d_dram_block_window_tmp, (warp_id == 1));
|
||||
|
||||
|
||||
|
||||
if (warp_id == 1)
|
||||
{
|
||||
//tile_elementwise_inout([](auto& c) { c = 5; }, a_tile);
|
||||
ReadA();
|
||||
//tile_elementwise_inout([](auto& c) { c = 1; }, b_tile_1);
|
||||
ReadB(warp_id);
|
||||
ComputeStep(warp_id);
|
||||
|
||||
//store_tile(epilogue_dram_window, cast_tile<ADataType>(c_block_tile_0));
|
||||
epilogue_func(epilogue_dram_window, c_block_tile_0, d_dram_block_window_tmp, (operation_id == 0));
|
||||
}
|
||||
*/
|
||||
|
||||
__syncthreads();
|
||||
// Read constant A.
|
||||
ReadA();
|
||||
//Read B
|
||||
if (operation_id == 0)
|
||||
{
|
||||
ReadB(warp_id);
|
||||
}
|
||||
|
||||
index_t num_steps = __builtin_amdgcn_readfirstlane(num_loop);
|
||||
while(num_steps > 1){
|
||||
block_sync_lds();
|
||||
operation_id = (operation_id + 1) % NumWaveGroups;
|
||||
|
||||
if(operation_id == 0)
|
||||
{
|
||||
ReadB(warp_id);
|
||||
}
|
||||
else
|
||||
{
|
||||
ComputeStep(warp_id);
|
||||
}
|
||||
__syncthreads();
|
||||
num_steps -= 1;
|
||||
|
||||
epilogue_func(epilogue_dram_window, c_block_tile_0, d_dram_block_window_tmp, (operation_id == 1));
|
||||
if (operation_id == 1)
|
||||
{
|
||||
move_tile_window(epilogue_dram_window, c_dram_tile_window_step);
|
||||
}
|
||||
}
|
||||
|
||||
if(operation_id == 0)
|
||||
{
|
||||
ComputeStep(warp_id);
|
||||
}
|
||||
|
||||
epilogue_func(epilogue_dram_window, c_block_tile_0, d_dram_block_window_tmp, (operation_id == 0));
|
||||
if (operation_id == 0)
|
||||
{
|
||||
move_tile_window(epilogue_dram_window, c_dram_tile_window_step);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// M Dimension parallelism here.
|
||||
template <bool HasHotLoop,
|
||||
TailNumber TailNum,
|
||||
index_t PingPongDim,
|
||||
@@ -578,6 +810,72 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
N Dimension parallelism here.
|
||||
*/
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename DDramBlockWindowTmp,
|
||||
typename CDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
typename EpilogueFunction>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
const DDramBlockWindowTmp& d_dram_block_window_tmp,
|
||||
const CDramBlockWindowTmp& c_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem_0,
|
||||
void* p_smem_1,
|
||||
const EpilogueFunction& epilogue_func) const
|
||||
{
|
||||
return PipelineImpl<Scheduler>{}
|
||||
.template operator()<HasHotLoop, TailNum, Problem::PingPongDim>(a_dram_block_window_tmp,
|
||||
a_element_func,
|
||||
b_dram_block_window_tmp,
|
||||
b_element_func,
|
||||
d_dram_block_window_tmp,
|
||||
c_dram_block_window_tmp,
|
||||
num_loop,
|
||||
p_smem_0,
|
||||
p_smem_1,
|
||||
epilogue_func);
|
||||
}
|
||||
|
||||
public:
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename DDramBlockWindowTmp,
|
||||
typename CDramBlockWindowTmp,
|
||||
typename EpilogueFunction>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const DDramBlockWindowTmp& d_dram_block_window_tmp,
|
||||
const CDramBlockWindowTmp& c_dram_block_window_tmp,
|
||||
const index_t num_loop,
|
||||
void* __restrict__ p_smem_0,
|
||||
void* __restrict__ p_smem_1,
|
||||
const EpilogueFunction& epilogue_func) const
|
||||
{
|
||||
return PipelineImpl<Scheduler>{}
|
||||
.template operator()<HasHotLoop, TailNum, Problem::PingPongDim>(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDataType& b) { return b; },
|
||||
d_dram_block_window_tmp,
|
||||
c_dram_block_window_tmp,
|
||||
num_loop,
|
||||
p_smem_0,
|
||||
p_smem_1,
|
||||
epilogue_func);
|
||||
}
|
||||
|
||||
/*
|
||||
// M dimensional parallelism
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename DDramBlockWindowTmp,
|
||||
@@ -633,6 +931,7 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
p_smem_0,
|
||||
epilogue_func);
|
||||
}
|
||||
*/
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user