mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 13:17:36 +00:00
Initial code drop
This commit is contained in:
@@ -238,9 +238,9 @@ struct GemmConfigComputeV4_1 : public GemmConfigBase
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV5 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
|
||||
static constexpr ck_tile::index_t M_Tile = 64;
|
||||
static constexpr ck_tile::index_t N_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Tile = 8; //64 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 1;
|
||||
@@ -248,11 +248,13 @@ struct GemmConfigComputeV5 : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 8; //get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5;
|
||||
static constexpr ck_tile::index_t NumWaNumWaveGroups = 2;
|
||||
|
||||
static constexpr ck_tile::index_t PingPongDim = 1; // 0 - Off, 1 - M, 2 - N and 3 - K
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -486,7 +488,7 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("stride_b", "0", "Tensor B stride")
|
||||
.insert("stride_c", "0", "Tensor C stride")
|
||||
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
|
||||
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
|
||||
.insert("prec", "bf16", "data type. fp16/bf16/fp8/bf8")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
|
||||
@@ -60,7 +60,8 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
GemmConfig::UseStructuredSparsity,
|
||||
Persistent,
|
||||
GemmConfig::NumWaveGroups,
|
||||
GemmConfig::Preshuffle>;
|
||||
GemmConfig::Preshuffle,
|
||||
GemmConfig::PingPongDim>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
@@ -94,6 +95,24 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
/*
|
||||
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<
|
||||
ck_tile::DefaultGemm2DEpilogueProblem<
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
false,
|
||||
memory_operation>>;
|
||||
*/
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
@@ -124,7 +143,14 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
}
|
||||
else
|
||||
{
|
||||
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
if constexpr(GemmConfig::PingPongDim == 0)
|
||||
{
|
||||
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
}
|
||||
else
|
||||
{
|
||||
grids = Kernel::PingPongGridSize(args.M, args.N, args.K, args.k_batch);
|
||||
}
|
||||
}
|
||||
dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
@@ -186,7 +212,21 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
Run(has_hot_loop_, tail_number_, MemoryOpSet{});
|
||||
//Run(has_hot_loop_, tail_number_, MemoryOpSet{});
|
||||
if constexpr(GemmConfig::PingPongDim == 0)
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -210,61 +250,18 @@ int run_gemm_example_prec_type(std::string a_layout,
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
bool preshuffle = GemmConfig::Preshuffle;
|
||||
|
||||
if(preshuffle && std::is_same_v<BPrecType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
throw std::runtime_error("Preshuffle is not supported for this int4 datatype!");
|
||||
}
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
if(preshuffle && a_layout != "R" && b_layout != "C")
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Preshuffle is supported only for A(Row major), B(column major) input matrices!");
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<BPrecType, ck_tile::pk_int4_t>)
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
|
||||
arg_parser, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
|
||||
arg_parser, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported memory layout for the input matrices when "
|
||||
"BPrecType is ck_tile::pk_int4_t!");
|
||||
}
|
||||
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
if(a_layout == "R" && b_layout == "R")
|
||||
{
|
||||
return run_gemm_example_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
|
||||
arg_parser, Row{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
|
||||
arg_parser, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "R")
|
||||
{
|
||||
return run_gemm_example_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
|
||||
arg_parser, Col{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
|
||||
arg_parser, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported memory layout for the input matrices!");
|
||||
}
|
||||
throw std::runtime_error("Unsupported memory layout for the input matrices!");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -275,52 +272,11 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::half_t>(
|
||||
a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
if(data_type == "bf16")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::bf16_t>(
|
||||
a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "int8")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::int8_t>,
|
||||
ck_tile::int8_t,
|
||||
ck_tile::int8_t,
|
||||
ck_tile::int32_t>(a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "pk_int4_t")
|
||||
{
|
||||
// TODO: Add support for bhalf_t ADataType
|
||||
if constexpr(GemmConfig<ck_tile::half_t>::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>,
|
||||
ck_tile::half_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported pipeline for this operation !!!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
|
||||
@@ -130,6 +130,11 @@ struct GemmKernel
|
||||
return UniversalGemmKernel::MaxOccupancyGridSize(s);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static auto PingPongGridSize(index_t, index_t N, index_t K, index_t KBatch) -> dim3
|
||||
{
|
||||
return dim3(TilePartitioner::GridSize(N, K), 1, KBatch);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() -> dim3
|
||||
{
|
||||
return UniversalGemmKernel::BlockSize();
|
||||
|
||||
@@ -264,6 +264,11 @@ struct GemmSpatiallyLocalTilePartitioner
|
||||
return integer_divide_ceil(K, KPerBlock);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE auto GetPingPongMLoops(index_t NumWavefronts) noexcept -> index_t
|
||||
{
|
||||
return integer_divide_ceil(M, MPerBlock * NumWavefronts);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Calculate workgroup 1D index mapping into 2D output C-tile space.
|
||||
*
|
||||
|
||||
@@ -284,6 +284,11 @@ struct UniversalGemmKernel
|
||||
return dim3(grid_size, 1, 1);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static auto PingPongGridSize(index_t, index_t N, index_t K, index_t KBatch) -> dim3
|
||||
{
|
||||
return dim3(TilePartitioner::GridSize(N, K), 1, KBatch);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static auto BlockSize()
|
||||
{
|
||||
if(ck_tile::is_wave32())
|
||||
@@ -845,6 +850,90 @@ struct UniversalGemmKernel
|
||||
}
|
||||
}
|
||||
|
||||
template <typename TensorView>
|
||||
CK_TILE_DEVICE static auto MakePingPongGemmPadViews(const TensorView& views)
|
||||
{
|
||||
const auto& a_pad_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
const auto& a_tensor_view = views.at(I0);
|
||||
using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{}, number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadM>{});
|
||||
}
|
||||
},
|
||||
number<NumATensor>{});
|
||||
|
||||
const auto& b_pad_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
const auto& b_tensor_view = views.at(I1);
|
||||
using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return pad_tensor_view(b_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(b_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadN>{});
|
||||
}
|
||||
},
|
||||
number<NumBTensor{});
|
||||
|
||||
const auto& d_pad_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
const auto& d_tensor_view = views.at(I2);
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(d_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadN>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(d_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadM>{});
|
||||
}
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
// TODO vector write in for C in ColMajor
|
||||
const auto& e_pad_view = [&]() {
|
||||
const auto& e_tensor_view = views.at(I2);
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(
|
||||
e_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadN>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(
|
||||
e_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<GemmPipeline::kPadM, false>{});
|
||||
}
|
||||
}();
|
||||
|
||||
return make_tuple(a_pad_view, b_pad_view, d_pad_view, e_pad_view);
|
||||
}
|
||||
|
||||
template <typename PadView>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
|
||||
@@ -934,6 +1023,79 @@ struct UniversalGemmKernel
|
||||
return make_tuple(as_block_window, bs_block_window, ds_block_window, e_block_window);
|
||||
}
|
||||
|
||||
template <typename PadView>
|
||||
CK_TILE_DEVICE static auto MakePingPongGemmTileWindows
|
||||
(const PadView& views, const index_t i_n, const index_t i_k, [[maybe_unused]] const index_t M, [[maybe_unused]] const index_t N, [[maybe_unused]] const index_t K)
|
||||
{
|
||||
const auto& as_pad_view = views.at(I0);
|
||||
const auto& bs_pad_view = views.at(I1);
|
||||
const auto& ds_pad_view = views.at(I2);
|
||||
const auto& e_pad_view = views.at(I3);
|
||||
|
||||
const auto& as_block_window = generate_tuple(
|
||||
[&](auto i) {
|
||||
using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
|
||||
if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(
|
||||
as_pad_view[i], make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}), {0, i_k});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(
|
||||
as_pad_view[i], make_tuple(number<TilePartitioner::KPerBlock>{}, number<TilePartitioner::MPerBlock>{}), {i_k, 0});
|
||||
}
|
||||
},
|
||||
number<NumATensor>{});
|
||||
|
||||
const auto& bs_block_window = generate_tuple(
|
||||
[&](auto i) {
|
||||
using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
|
||||
if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return make_tile_window(
|
||||
bs_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
|
||||
{i_n, i_k});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(
|
||||
b_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{i_k, i_n});
|
||||
}
|
||||
},
|
||||
number<NumBTensor>{});
|
||||
|
||||
const auto& ds_block_window = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor)
|
||||
{
|
||||
return make_tile_window(
|
||||
ds_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(
|
||||
ds_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::MPerBlock>{}),
|
||||
{i_n, i_m});
|
||||
}
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
auto e_block_window = make_tile_window(
|
||||
e_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{0, i_n});
|
||||
|
||||
return make_tuple(as_block_window, bs_block_window, ds_block_window, e_block_window);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Runs single GEMM problem cooperatively by whole workgroup.
|
||||
*
|
||||
@@ -987,6 +1149,43 @@ struct UniversalGemmKernel
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static void PingPongGemm(const ADataType* a_ptr,
|
||||
const BDataType* b_ptr,
|
||||
CDataType* c_ptr,
|
||||
void* smem_ptr_0,
|
||||
const GemmKernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_n,
|
||||
const index_t block_idx_k)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset);
|
||||
|
||||
const auto& gemm_pad_views =
|
||||
MakePingPongGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows =
|
||||
MakePingPongGemmTileWindows(gemm_pad_views, block_idx_n, block_idx_k, kargs.M, kargs.N, kargs.K);
|
||||
|
||||
const index_t num_loop = __builtin_amdgcn_readfirstlane(integer_divide_ceil(
|
||||
//kargs.M, TilePartitioner::MPerBlock * GemmPipeline::BlockGemmShape::NumWarps));
|
||||
kargs.M, TilePartitioner::MPerBlock));
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(I1);
|
||||
auto& e_block_window = gemm_tile_windows.at(I3);
|
||||
|
||||
const auto EpilogueFunc = [&](auto &window, auto& tile) {
|
||||
EpiloguePipeline{}.template operator()<decltype(window), decltype(tile)>(
|
||||
window, tile, smem_ptr_0);
|
||||
};
|
||||
|
||||
GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, e_block_window, num_loop, smem_ptr_0, EpilogueFunc);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Runs single GEMM problem cooperatively by whole workgroup.
|
||||
*
|
||||
@@ -1045,9 +1244,9 @@ struct UniversalGemmKernel
|
||||
CK_TILE_DEVICE void operator()(KernelArgs kargs) const
|
||||
{
|
||||
const auto blockId = __builtin_amdgcn_readfirstlane(blockIdx.x);
|
||||
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
|
||||
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
|
||||
const auto [iN, iK] = TilePartitioner{kargs.N, kargs.K}.GetOutputTileIndex(blockId);
|
||||
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
|
||||
const index_t i_k = __builtin_amdgcn_readfirstlane(iK * TilePartitioner::KPerBlock);
|
||||
|
||||
const SplitKBatchOffset splitk_batch_offset(kargs);
|
||||
|
||||
@@ -1075,43 +1274,8 @@ struct UniversalGemmKernel
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_0[GetSmemSize()];
|
||||
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<EDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm2LDS(as_ptr,
|
||||
bs_ptr,
|
||||
kargs.ds_ptr,
|
||||
e_ptr,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<EDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
constexpr auto scheduler_type = (GemmPipeline::NumWaveGroups == 1);
|
||||
RunGemm<scheduler_type>(as_ptr,
|
||||
bs_ptr,
|
||||
kargs.ds_ptr,
|
||||
e_ptr,
|
||||
smem_ptr_0,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
}
|
||||
PingPongGemm(
|
||||
a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_n, i_k);
|
||||
}
|
||||
|
||||
// Persistent kernel entry point
|
||||
|
||||
@@ -49,6 +49,7 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
static constexpr index_t PingPongDim = Problem::PingPongDim;
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
|
||||
using I0 = number<0>;
|
||||
@@ -85,6 +86,7 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
|
||||
static constexpr index_t NumWarps = BlockGemmShape::NumWarps;
|
||||
static constexpr index_t WaveStep = NumWarps * MPerBlock;
|
||||
static constexpr index_t KTileSize = BlockGemmShape::WarpTile::at(I2{});
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
@@ -116,6 +118,213 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
{
|
||||
using Base = PipelineImplBase;
|
||||
|
||||
template <bool HasHotLoop,
|
||||
TailNumber TailNum,
|
||||
index_t PingPongDim,
|
||||
typename ADramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename BElementFunction,
|
||||
typename CDramBlockWindowTmp,
|
||||
typename EpilogueFunction>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
[[maybe_unused]] CDramBlockWindowTmp& c_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* __restrict__ p_smem_0,
|
||||
[[maybe_unused]] const EpilogueFunction& epilogue_func) const
|
||||
{
|
||||
//static_assert((MPerBlock * num_loop * NumWarps) == Problem::kM,
|
||||
// "Ping Pong Warps, Tile size and Block size for M dimension does not match.");
|
||||
|
||||
constexpr bool is_a_col_major =
|
||||
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
|
||||
index_t warp_id = get_warp_id();
|
||||
index_t operation_id =
|
||||
__builtin_amdgcn_readfirstlane(get_warp_id()); // 0 - Memory read, 1 - block-gemm
|
||||
|
||||
auto a_offset = (warp_id == 0) ? make_array(0, 0) : make_array(MPerBlock, 0);
|
||||
auto c_offset = (warp_id == 0) ? make_array(0, 0) : make_array(MPerBlock, 0);
|
||||
|
||||
auto tensor_views =
|
||||
Base::GetABLdsTensorViews(static_cast<void*>(static_cast<char*>(p_smem_0)));
|
||||
auto& a_lds_block = tensor_views.get(number<0>{});
|
||||
auto& b_lds_block = tensor_views.get(number<1>{});
|
||||
|
||||
constexpr auto a_lds_load_tile_distr =
|
||||
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
|
||||
constexpr auto b_lds_load_tile_distr =
|
||||
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
|
||||
|
||||
auto a_windows = Base::GetAWindows(
|
||||
a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr, a_offset);
|
||||
auto& a_copy_dram_window = a_windows.get(number<0>{});
|
||||
auto& a_copy_lds_window = a_windows.get(number<1>{});
|
||||
auto& a_lds_window = a_windows.get(number<2>{});
|
||||
|
||||
auto b_windows = Base::GetBWindows(
|
||||
b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);
|
||||
auto& b_copy_dram_window = b_windows.get(number<0>{});
|
||||
auto& b_copy_lds_window = b_windows.get(number<1>{});
|
||||
auto& b_lds_window = b_windows.get(number<2>{});
|
||||
|
||||
auto epilogue_dram_window = make_tile_window(c_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(MPerBlock, NPerBlock),
|
||||
c_dram_block_window_tmp.get_window_origin() + c_offset);
|
||||
|
||||
// Add the offset which is warp specific so that subsequently we can increase it
|
||||
// with a fixed step size, which is also independent of the warp id.
|
||||
//c_dram_block_window_tmp += c_offset;
|
||||
//move_tile_window(c_dram_block_window_tmp, c_offset);
|
||||
|
||||
// DRAM window steps.
|
||||
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
|
||||
constexpr ADramTileWindowStep a_dram_tile_window_step =
|
||||
is_a_col_major ? make_array(MPerBlock * NumWarps, 0)
|
||||
: make_array(0, MPerBlock * NumWarps);
|
||||
|
||||
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, 0);
|
||||
|
||||
using CDramTileWindowStep = typename CDramBlockWindowTmp::BottomTensorIndex;
|
||||
constexpr CDramTileWindowStep c_dram_tile_window_step =
|
||||
is_a_col_major ? make_array(KPerBlock * NumWarps, 0) : make_array(0, KPerBlock * NumWarps);
|
||||
|
||||
constexpr auto AGemmTileDistr = decltype(make_static_tile_distribution(
|
||||
BlockGemm::MakeABlockDistributionEncode())){};
|
||||
constexpr auto BGemmTileDistr = decltype(make_static_tile_distribution(
|
||||
BlockGemm::MakeBBlockDistributionEncode())){};
|
||||
|
||||
using AGemmTile = decltype(make_static_distributed_tensor<ADataType>(AGemmTileDistr));
|
||||
using BGemmTile = decltype(make_static_distributed_tensor<BDataType>(BGemmTileDistr));
|
||||
AGemmTile a_tile_0, a_tile_1;
|
||||
BGemmTile b_tile;
|
||||
|
||||
// Register tile for A and B.
|
||||
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
|
||||
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
|
||||
using ABlockTile =
|
||||
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
|
||||
using BBlockTile =
|
||||
decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
|
||||
ABlockTile a_global_load_tile;
|
||||
BBlockTile b_global_load_tile;
|
||||
|
||||
// Block GEMM
|
||||
auto block_gemm = BlockGemm();
|
||||
auto c_block_tile_0 = block_gemm.MakeCBlockTile();
|
||||
auto c_block_tile_1 = block_gemm.MakeCBlockTile();
|
||||
|
||||
// initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 1; }, c_block_tile_0);
|
||||
tile_elementwise_inout([](auto& c) { c = 2; }, c_block_tile_1);
|
||||
|
||||
auto BReadOps = [&](){
|
||||
Base::GlobalPrefetch(
|
||||
b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window, b_global_load_tile, b_element_func);
|
||||
}
|
||||
Base::LocalPrefetch(b_tile, b_lds_window);
|
||||
};
|
||||
|
||||
// define ping, pong steps here as lambda functions.
|
||||
auto MemoryOpsStep = [&](auto idx) {
|
||||
// Memory read half here.
|
||||
Base::GlobalPrefetch(
|
||||
a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
|
||||
if constexpr(is_a_col_major)
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, a_global_load_tile, a_element_func);
|
||||
}
|
||||
|
||||
if(idx == 0)
|
||||
{
|
||||
Base::LocalPrefetch(a_tile_0, a_lds_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefetch(a_tile_1, a_lds_window);
|
||||
}
|
||||
};
|
||||
|
||||
auto ComputeStep = [&](auto idx) {
|
||||
if(idx == 0)
|
||||
{
|
||||
tile_elementwise_inout([](auto& c) { c = 1; }, c_block_tile_0);
|
||||
//block_gemm(c_block_tile_0, a_tile_0, b_tile);
|
||||
|
||||
epilogue_func(epilogue_dram_window, c_block_tile_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([](auto& c) { c = 1; }, c_block_tile_1);
|
||||
//block_gemm(c_block_tile_1, a_tile_1, b_tile);
|
||||
|
||||
epilogue_func(epilogue_dram_window, c_block_tile_1);
|
||||
}
|
||||
};
|
||||
|
||||
// Read B block tile
|
||||
BReadOps();
|
||||
|
||||
if(operation_id == 0)
|
||||
{
|
||||
MemoryOpsStep(warp_id);
|
||||
}
|
||||
|
||||
index_t num_compute_steps = __builtin_amdgcn_readfirstlane(num_loop);
|
||||
while(num_compute_steps > 100)
|
||||
{
|
||||
block_sync_lds();
|
||||
operation_id = (operation_id + 1) % NumWaveGroups;
|
||||
|
||||
if(operation_id == 0)
|
||||
{
|
||||
MemoryOpsStep(warp_id);
|
||||
//move_tile_window(c_dram_block_window_tmp, {WaveStep, 0});
|
||||
epilogue_dram_window = make_tile_window(epilogue_dram_window.get_bottom_tensor_view(),
|
||||
make_tuple(MPerBlock, NPerBlock),
|
||||
epilogue_dram_window.get_window_origin() + c_dram_tile_window_step);
|
||||
}
|
||||
else
|
||||
{
|
||||
ComputeStep(warp_id);
|
||||
}
|
||||
|
||||
num_compute_steps -= 1;
|
||||
}
|
||||
block_sync_lds();
|
||||
|
||||
|
||||
if(operation_id == 0)
|
||||
{
|
||||
ComputeStep(warp_id);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
template <bool HasHotLoop,
|
||||
TailNumber TailNum,
|
||||
typename ADramBlockWindowTmp,
|
||||
@@ -350,39 +559,49 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename CDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction>
|
||||
typename BElementFunction,
|
||||
typename EpilogueFunction>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
const CDramBlockWindowTmp& c_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem_0) const
|
||||
void* p_smem_0,
|
||||
const EpilogueFunction& epilogue_func) const
|
||||
{
|
||||
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
|
||||
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum, Problem::PingPongDim>(
|
||||
a_dram_block_window_tmp,
|
||||
a_element_func,
|
||||
b_dram_block_window_tmp,
|
||||
b_element_func,
|
||||
c_dram_block_window_tmp,
|
||||
num_loop,
|
||||
p_smem_0);
|
||||
p_smem_0,
|
||||
epilogue_func);
|
||||
}
|
||||
|
||||
public:
|
||||
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
|
||||
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp, typename CDramBlockWindowTmp, typename EpilogueFunction>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const CDramBlockWindowTmp& c_dram_block_window_tmp,
|
||||
const index_t num_loop,
|
||||
void* __restrict__ p_smem_0) const
|
||||
void* __restrict__ p_smem_0,
|
||||
const EpilogueFunction& epilogue_func) const
|
||||
{
|
||||
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
|
||||
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum, Problem::PingPongDim>(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDataType& b) { return b; },
|
||||
c_dram_block_window_tmp,
|
||||
num_loop,
|
||||
p_smem_0);
|
||||
}
|
||||
p_smem_0,
|
||||
epilogue_func);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -212,6 +212,8 @@ struct UniversalGemmPipelineProblem
|
||||
|
||||
static constexpr bool TransposeC = Traits::TransposeC;
|
||||
static constexpr index_t NumWaveGroups = Traits::NumWaveGroups;
|
||||
static constexpr index_t PingPongDim = Traits::PingPongDim;
|
||||
|
||||
static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity;
|
||||
|
||||
static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
|
||||
|
||||
@@ -43,7 +43,8 @@ template <bool kPadM_,
|
||||
bool UseStructuredSparsity_ = false,
|
||||
bool UsePersistentKernel_ = false,
|
||||
index_t NumWaveGroups_ = 1,
|
||||
bool Preshuffle_ = 0>
|
||||
bool Preshuffle_ = 0,
|
||||
index_t PingPongDim = 0>
|
||||
struct TileGemmUniversalTraits
|
||||
{
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
@@ -61,6 +62,7 @@ struct TileGemmUniversalTraits
|
||||
static constexpr bool UsePersistentKernel = UsePersistentKernel_;
|
||||
static constexpr index_t NumWaveGroups = NumWaveGroups_;
|
||||
static constexpr bool Preshuffle = Preshuffle_;
|
||||
static constexpr index_t PingPongDim = PingPongDim_;
|
||||
};
|
||||
|
||||
template <bool kPadM_,
|
||||
|
||||
Reference in New Issue
Block a user