Mirror of https://github.com/ROCm/composable_kernel.git, synced 2026-04-20 06:49:15 +00:00
[CK-Tile] Enable vectorized reads on all layouts & improve perf. (#1835)
* Refactor universal GEMM policy.
* Adapt example to the refactoring changes.
* Introduce static encoding pattern.
* Add shuffled encoding patterns.
* Fix error in reverse tuple.
* Add transpose_tile2d.
* Small refactoring + doc.
* Enable reading on the contiguous dimension in all layouts.
* Transpose A/B register tile if needed for the comp v3 pipeline.
* Take contiguous dimension size when calculating the DRAM vector load size.
* Take A/B smem pack size from WarpGemm attributes.
* Update B LDS layout and set up the tile distribution pattern at class level.
* Fix static assert.
* Fix errors in examples.
* Formatting & fix IsTranspose.
* Fix VectorSize & refactor.
* Add error logging messages.
* Fix VecLoadSize and TransposeC for the mem pipeline.
* Update unit tests & disable the mem pipeline.
* Clang format.
* Update include/ck_tile/core/tensor/tile_window.hpp
* Fix compilation and reviewers' comments.
* Refactor unit test. Fall back to non-universal GEMM; GemmPipelineAGmemBGmemCRegV1 must be used for now, since GemmKernel now also supports non-K-major vector reads.

Co-authored-by: jakpiase <jakub.piasecki@amd.com>
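The central idea behind these changes, per the commit message, is that the DRAM vector load size is derived from whichever dimension is contiguous in memory for the given layout, rather than always being taken along K. Below is a minimal standalone sketch of that selection; the names (RowMajor, ColumnMajor, GetAVectorLoadSize, MaxVector) are illustrative placeholders, not the actual ck_tile policy API.

#include <type_traits>

// Illustrative stand-ins for ck_tile::tensor_layout::gemm::{RowMajor, ColumnMajor}.
struct RowMajor {};
struct ColumnMajor {};

// For an A block tile of MPerBlock x KPerBlock:
//   RowMajor A    -> K is the contiguous (fastest-varying) dimension,
//   ColumnMajor A -> M is the contiguous dimension.
// The vector load size is therefore derived from the contiguous dimension's
// extent, capped at MaxVector (a hardware/data-type limit), instead of
// always assuming K-major reads.
template <typename ALayout, int MPerBlock, int KPerBlock, int MaxVector>
constexpr int GetAVectorLoadSize()
{
    constexpr int contiguous_dim_size =
        std::is_same_v<ALayout, RowMajor> ? KPerBlock : MPerBlock;
    int vec = MaxVector;
    while(vec > 1 && contiguous_dim_size % vec != 0)
        vec /= 2; // fall back to the largest power of two that still divides the extent
    return vec;
}

static_assert(GetAVectorLoadSize<RowMajor, 128, 32, 8>() == 8);    // 8-wide reads along K
static_assert(GetAVectorLoadSize<ColumnMajor, 128, 32, 8>() == 8); // 8-wide reads along M
static_assert(GetAVectorLoadSize<RowMajor, 128, 12, 8>() == 4);    // K = 12 only allows 4-wide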
@@ -8,7 +8,6 @@

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"

namespace ck_tile {

@@ -69,6 +68,7 @@ struct GemmKernel

using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
// Below type is actually accumulation data type - the output of block GEMM.
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;

static constexpr auto I0 = number<0>();

@@ -168,6 +168,7 @@ struct GemmKernel
{
if(kargs.KBatch != 1)
{
std::cerr << "Conditions not met for Kbatch >1 !" << std::endl;
return false;
}
}

@@ -176,10 +177,14 @@ struct GemmKernel
{
if(kargs.K % TilePartitioner::KPerBlock != 0 && GemmPipeline::kPadK == false)
{
std::cerr << "Can't support K that is not a multiple of KPerBlock"
" without padding!"
<< std::endl;
return false;
}
if(kargs.K % GemmPipeline::VectorSizeA != 0)
{
std::cerr << "K is not a multiple of vector load size for A tensor!" << std::endl;
return false;
}
}

@@ -187,10 +192,14 @@ struct GemmKernel
{
if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
{
std::cerr << "Can't support M that is not a multiple of MPerBlock"
" without padding!"
<< std::endl;
return false;
}
if(kargs.M % GemmPipeline::VectorSizeA != 0)
{
std::cerr << "M is not a multiple of vector load size for A tensor!" << std::endl;
return false;
}
}

@@ -199,10 +208,14 @@ struct GemmKernel
{
if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
{
std::cerr << "Can't support N that is not a multiple of NPerBlock"
" without padding!"
<< std::endl;
return false;
}
if(kargs.N % GemmPipeline::VectorSizeB != 0)
{
std::cerr << "N is not a multiple of vector load size for B tensor!" << std::endl;
return false;
}
}

@@ -210,10 +223,14 @@ struct GemmKernel
{
if(kargs.K % TilePartitioner::KPerBlock != 0 && GemmPipeline::kPadK == false)
{
std::cerr << "Can't support K that is not a multiple of KPerBlock"
" without padding!"
<< std::endl;
return false;
}
if(kargs.K % GemmPipeline::VectorSizeB != 0)
{
std::cerr << "K is not a multiple of vector load size for B tensor!" << std::endl;
return false;
}
}

@@ -222,10 +239,14 @@ struct GemmKernel
{
if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
{
std::cerr << "Can't support N that is not a multiple of NPerBlock"
" without padding!"
<< std::endl;
return false;
}
if(kargs.N % GemmPipeline::VectorSizeC != 0)
{
std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl;
return false;
}
}

@@ -233,10 +254,14 @@ struct GemmKernel
{
if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
{
std::cerr << "Can't support M that is not a multiple of MPerBlock"
" without padding!"
<< std::endl;
return false;
}
if(kargs.M % GemmPipeline::VectorSizeC != 0)
{
std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl;
return false;
}
}
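All of the argument checks above follow the same pattern: the extent along which the vectorized access runs (K, M, or N, depending on layout) must be an exact multiple of the corresponding VectorSizeA/B/C, otherwise the kernel rejects the arguments. A tiny standalone illustration of that constraint (hypothetical helper, not ck_tile code):

#include <iostream>

// A vectorized access of `vector_size` elements along a dimension of extent
// `extent` is only valid if the extent divides evenly; otherwise the final
// access would step past the end of that dimension.
bool is_vector_access_valid(int extent, int vector_size)
{
    return extent % vector_size == 0;
}

int main()
{
    std::cout << is_vector_access_valid(4096, 8) << '\n'; // 1: e.g. K = 4096 with 8-wide loads
    std::cout << is_vector_access_valid(1000, 8) << '\n'; // 0: needs padding or a smaller vector
}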
@@ -250,6 +275,14 @@ struct GemmKernel
const GemmKernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset)
{
// const auto idxs = TilePartitioner{}();
// const auto i_m = idxs.at(number<0>{});
// const auto i_n = idxs.at(number<1>{});
// // options
// const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr);
// const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
// // Convert pointers to tensor views
// auto a_tensor_view = [&]() {
const auto& a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{

@@ -264,9 +297,9 @@ struct GemmKernel
{
return make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(kargs.M, splitk_batch_offset.splitted_k),
make_tuple(1, kargs.stride_A),
number<1>{},
make_tuple(splitk_batch_offset.splitted_k, kargs.M),
make_tuple(kargs.stride_A, 1),
number<GemmPipeline::VectorSizeA>{},
number<1>{});
}
}();
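In the ColumnMajor branch above, the A view is now described as (K, M) with strides (stride_A, 1), so the unit stride, and with it GemmPipeline::VectorSizeA, sits on the M dimension. A small standalone sketch of why column-major storage makes the row index the vectorizable one (illustration of the addressing only, not the ck_tile view machinery):

#include <cstddef>

// Column-major A with leading dimension lda (stride between consecutive columns):
// element (m, k) lives at offset m + k * lda.
constexpr std::size_t col_major_offset(std::size_t m, std::size_t k, std::size_t lda)
{
    return m + k * lda;
}

// Neighbouring rows are adjacent in memory, so a wide load can run along M...
static_assert(col_major_offset(1, 0, 4096) - col_major_offset(0, 0, 4096) == 1);
// ...while neighbouring columns are lda elements apart, so K is the strided dimension.
static_assert(col_major_offset(0, 1, 4096) - col_major_offset(0, 0, 4096) == 4096);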
@@ -276,9 +309,9 @@ struct GemmKernel
{
return make_naive_tensor_view<address_space_enum::global>(
b_ptr,
make_tuple(kargs.N, splitk_batch_offset.splitted_k),
make_tuple(1, kargs.stride_B),
number<1>{},
make_tuple(splitk_batch_offset.splitted_k, kargs.N),
make_tuple(kargs.stride_B, 1),
number<GemmPipeline::VectorSizeB>{},
number<1>{});
}
else

@@ -292,6 +325,7 @@ struct GemmKernel
}
}();

// TODO: enable vector write for C in ColMajor
const auto& c_tensor_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{

@@ -331,9 +365,9 @@ struct GemmKernel
else
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<GemmPipeline::kPadM, false>{});
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<false, GemmPipeline::kPadM>{});
}
}();

@@ -349,12 +383,13 @@ struct GemmKernel
else
{
return pad_tensor_view(b_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<GemmPipeline::kPadN, false>{});
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, GemmPipeline::kPadN>{});
}
}();

// TODO vector write in for C in ColMajor
const auto& c_pad_view = [&]() {
const auto& c_tensor_view = views.at(I2);
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)

@@ -380,20 +415,45 @@ struct GemmKernel
CK_TILE_DEVICE static auto
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
{
const auto& a_pad_view = views.at(I0);
const auto& a_block_window = make_tile_window(
a_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
{i_m, 0});

const auto& b_pad_view = views.at(I1);
const auto& b_block_window = make_tile_window(
b_pad_view,
make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
{i_n, 0});

const auto& a_pad_view = views.at(I0);
const auto& b_pad_view = views.at(I1);
const auto& c_pad_view = views.at(I2);
auto c_block_window = make_tile_window(

const auto& a_block_window = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_m, 0});
}
else
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
{0, i_m});
}
}();

const auto& b_block_window = [&]() {
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{
return make_tile_window(b_pad_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_n, 0});
}
else
{
return make_tile_window(b_pad_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{0, i_n});
}
}();

auto c_block_window = make_tile_window(
c_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});