mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 04:31:25 +00:00
[CK-Tile] Universal gemm memory bound pipeline (#1558)
* CK-Tile GEMM with memory bound pipeline. * Memory bound gemm pipeline. * Fix not closed namespace. * Block gemm mem pipeline draft. * Do not use ck_tile:: within ck_tile namespace. * Refactoring & Move Layout info to pipeline problem. * Get hot loop and TailNum information before lunching kernel. * Fixes in pipeline. * Add comment to load_tile_raw and change variable naming style. * Few small changes & formatting. * Do not use macro. * Add gtests. * Use AccDataType for Output of MFMA instruction. * Formatting. * Refactor gemm examples. * Switch over to current block gemm. * Use currently available pipeline policy. * Refactoring and review comment.s * Fixes after merge. * Add missing include. * Add load tile overload which accepts output tensor as parameter. * This give 8% perf boost at the cost of using more registers. * Rename example. * Small changes. * Fix compilation err and lower K. * Support different layouts for A/B * Fix vector size for different layouts. * Rename Alignment into VectorSize * Unblock tests.
This commit is contained in:
@@ -4,12 +4,10 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include <string>
|
||||
|
||||
template <typename DataType>
|
||||
struct GemmBasicTypeConfig;
|
||||
@@ -20,7 +18,7 @@ struct GemmBasicTypeConfig<ck_tile::half_t>
|
||||
using ADataType = ck_tile::half_t;
|
||||
using BDataType = ck_tile::half_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t; // type convert
|
||||
using CDataType = ck_tile::half_t;
|
||||
// ToDo: Add more bias config to support different categories of GEMM.
|
||||
};
|
||||
|
||||
@@ -58,7 +56,6 @@ struct gemm_basic_args
|
||||
const void* p_a;
|
||||
const void* p_b;
|
||||
void* p_c;
|
||||
float epsilon;
|
||||
ck_tile::index_t kbatch;
|
||||
ck_tile::index_t M;
|
||||
ck_tile::index_t N;
|
||||
@@ -68,5 +65,28 @@ struct gemm_basic_args
|
||||
ck_tile::index_t stride_C;
|
||||
};
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("b", "1", "batch size")
|
||||
.insert("m", "3840", "m dimension")
|
||||
.insert("n", "4096", "n dimension")
|
||||
.insert("k", "2048", "k dimension")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default")
|
||||
.insert("b_layout", "R", "B tensor data layout - Row by default")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default")
|
||||
.insert("stride_a", "0", "Tensor A stride")
|
||||
.insert("stride_b", "0", "Tensor B stride")
|
||||
.insert("stride_c", "0", "Tensor C stride")
|
||||
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
|
||||
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
// host API
|
||||
float gemm_calc(gemm_basic_args args, const ck_tile::stream_config& s);
|
||||
|
||||
Reference in New Issue
Block a user