mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK-Tile] Enable vectorized reads on all layouts & improve perf. (#1835)
* Refactor universal gemm policy.
* Adapt example to refactor changes.
* Introduce static encoding pattern
* Adding shuffled encoding patterns.
* Fix err in reverse tuple.
* Add transpose_tile2d
* Small refactoring + doc
* Enable reading on contiguous dimension in all layouts.
* Transpose A/B register tile if needed for comp v3 pipeline.
* Take contiguous dim size when calculating dram vector load size.
* A/B smem pack size taken from WarpGemm attributes
* Update B LDS layout and setup tile distribution pattern at class level.
* Fix static assert.
* Fix errors in examples.
* Formatting & fix IsTranspose
* Fix VectorSize & refactor.
* Add error loging messages.
* Fix VecLoadSize and TranspseC for mem pipeline.
* Update unit-tests & disable mem pipeline.
* Clang format
* Update include/ck_tile/core/tensor/tile_window.hpp
Co-authored-by: jakpiase <jakub.piasecki@amd.com>
* Fix compilation and reviewers comments.
* Refactor unit-test. Fallback to non-universal gemm.
Need to use GemmPipelineAGmemBGmemCRegV1 for now,
since GemmKernel is now supporting also non-K major vector reads.
---------
Co-authored-by: jakpiase <jakub.piasecki@amd.com>
[ROCm/composable_kernel commit: 39dc25a9b8]
This commit is contained in:
@@ -70,9 +70,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
|
||||
using CodegenPipelineProblem = ck_tile::
|
||||
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
|
||||
using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy;
|
||||
using CodegenGemmPipeline =
|
||||
ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem, CodegenGemmPolicy>;
|
||||
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
|
||||
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
|
||||
@@ -103,4 +101,26 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
int run_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
|
||||
|
||||
Reference in New Issue
Block a user