mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 20:09:25 +00:00
Ck tile gemm example (#1488)
* Checkpoint: Finished with the tile example & kernel verification, working on the different matrix layout
* Finished the Matrix Layout feature set up. Note: Need to modify the inner block to solve the shuffle problem in the future.
* Fix: Clang Format, API fixed from fmha
* fix with better naming convention
* revert back the pipeline code of fmha
* Fixed: Addressed the comments and merge the GEMM shape of GEMM Operator and FMHA Operator to one.
* clang format with the reference_gemm file
* convert the clang format with the remod.py
* Changed the format and variable name of the kernel gemm_shape and partitioner
---------
Co-authored-by: thomasning <thomasning@banff-cyxtera-s70-4.ctr.dcgpu>
[ROCm/composable_kernel commit: caacd38830]
This commit is contained in:
2
example/ck_tile/03_gemm/CMakeLists.txt
Normal file
2
example/ck_tile/03_gemm/CMakeLists.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
set(CMAKE_BUILD_TYPE Debug)
|
||||
add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp)
|
||||
23
example/ck_tile/03_gemm/README.md
Normal file
23
example/ck_tile/03_gemm/README.md
Normal file
@@ -0,0 +1,23 @@
|
||||
# GEMM Matrix Multiplication
|
||||
|
||||
This folder contains example for GEMM using ck_tile tile-programming implementation. Currently, it only supports the basic feature of the CK Tile GEMM, but creates the placeholders for the future support on different GEMM pipeline and different GEMM modules. In the near future, we will gradually migrate all the GEMM features from old CK to CK Tile.
|
||||
|
||||
## build
|
||||
```
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
|
||||
make tile_example_gemm_basic -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_gemm_basic`
|
||||
|
||||
## example
|
||||
```
|
||||
args:
|
||||
-m m dimension (default:3328)
|
||||
-n m dimension (default:4096)
|
||||
-k k dimension (default:64)
|
||||
-e epsilon (default:1e-5)
|
||||
-v cpu validation or not (default:1)
|
||||
-prec precision (default:fp16)
|
||||
```
|
||||
274
example/ck_tile/03_gemm/gemm_basic.cpp
Normal file
274
example/ck_tile/03_gemm/gemm_basic.cpp
Normal file
@@ -0,0 +1,274 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gemm_basic.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("b", "1", "batch size")
|
||||
.insert("m", "1024", "m dimension")
|
||||
.insert("n", "2048", "n dimension")
|
||||
.insert("k", "64", "k dimension")
|
||||
.insert("stride_a", "0", "Tensor A stride")
|
||||
.insert("stride_b", "0", "Tensor B stride")
|
||||
.insert("stride_c", "0", "Tensor C stride")
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("e", "1e-5", "Absolute error tolerance")
|
||||
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
|
||||
.insert("warmup", "10", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename LayoutA, typename LayoutB, typename LayoutC>
|
||||
float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
// ToDo: This will be modified by the codegen code later.
|
||||
constexpr ck_tile::index_t M_Tile = 128;
|
||||
constexpr ck_tile::index_t N_Tile = 128;
|
||||
constexpr ck_tile::index_t K_Tile = 32;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 8;
|
||||
|
||||
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
|
||||
constexpr bool kPadA = true;
|
||||
constexpr bool kPadB = true;
|
||||
constexpr bool kPadC = false;
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
// ===============================================
|
||||
|
||||
using GemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>;
|
||||
using PipelineProblem = ck_tile::
|
||||
BlockGemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, kPadA, kPadB, kPadC>;
|
||||
// The GemmPipeline should also come from the Codegen.
|
||||
using GemmPipeline = ck_tile::BlockGemmPipelineAGmemBGmemCRegV1<PipelineProblem>;
|
||||
using GemmEpilogue = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadA, kPadB>>;
|
||||
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
|
||||
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
|
||||
using Kernel =
|
||||
ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue, LayoutA, LayoutB, LayoutC>;
|
||||
|
||||
auto kargs = Kernel::MakeKargs(args.p_a,
|
||||
args.p_b,
|
||||
args.p_c,
|
||||
args.epsilon,
|
||||
args.M,
|
||||
args.N,
|
||||
args.K,
|
||||
args.stride_A,
|
||||
args.stride_B,
|
||||
args.stride_C);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename DataType, typename LayoutA, typename LayoutB, typename LayoutC>
|
||||
float invoke_gemm(ck_tile::DeviceMem& a_buf,
|
||||
ck_tile::DeviceMem& b_buf,
|
||||
ck_tile::DeviceMem& c_buf,
|
||||
const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
|
||||
if(data_type != DataTypeTraits<DataType>::name)
|
||||
{
|
||||
std::cerr << "Data type mismatch: expected " << DataTypeTraits<DataType>::name << ", got "
|
||||
<< data_type << std::endl;
|
||||
return -1; // Or handle the error appropriately
|
||||
}
|
||||
|
||||
float epsilon = arg_parser.get_float("e");
|
||||
ck_tile::index_t batch_size = arg_parser.get_int("b");
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t K = arg_parser.get_int("k");
|
||||
|
||||
ck_tile::index_t stride_a = arg_parser.get_int("stride_a");
|
||||
ck_tile::index_t stride_b = arg_parser.get_int("stride_b");
|
||||
ck_tile::index_t stride_c = arg_parser.get_int("stride_c");
|
||||
|
||||
gemm_basic_args args;
|
||||
args.p_a = a_buf.GetDeviceBuffer();
|
||||
args.p_b = b_buf.GetDeviceBuffer();
|
||||
args.p_c = c_buf.GetDeviceBuffer();
|
||||
args.epsilon = epsilon;
|
||||
args.kbatch = batch_size;
|
||||
args.M = M;
|
||||
args.N = N;
|
||||
args.K = K;
|
||||
|
||||
// Only set stride_M and stride_N if they are non-zero and not equal to K.
|
||||
if(stride_a != 0)
|
||||
{
|
||||
args.stride_A = stride_a;
|
||||
}
|
||||
else
|
||||
{
|
||||
args.stride_A = [&]() {
|
||||
if constexpr(std::is_same_v<LayoutA, ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return M;
|
||||
}
|
||||
else
|
||||
{
|
||||
return K;
|
||||
}
|
||||
}();
|
||||
}
|
||||
|
||||
if(stride_b != 0)
|
||||
{
|
||||
args.stride_B = stride_b;
|
||||
}
|
||||
else
|
||||
{
|
||||
args.stride_B = [&]() {
|
||||
if constexpr(std::is_same_v<LayoutB, ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return N;
|
||||
}
|
||||
else
|
||||
{
|
||||
return K;
|
||||
}
|
||||
}();
|
||||
}
|
||||
|
||||
if(stride_c != 0)
|
||||
{
|
||||
args.stride_C = stride_c;
|
||||
}
|
||||
else
|
||||
{
|
||||
args.stride_C = [&]() {
|
||||
if constexpr(std::is_same_v<LayoutC, ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return M;
|
||||
}
|
||||
else
|
||||
{
|
||||
return N;
|
||||
}
|
||||
}();
|
||||
}
|
||||
|
||||
float ave_time =
|
||||
gemm_calc<LayoutA, LayoutB, LayoutC>(args, ck_tile::stream_config{nullptr, true});
|
||||
std::size_t num_byte =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "The overall perfomance of the GEMM with "
|
||||
<< "[" << data_type << "]"
|
||||
<< "batch size: " << batch_size << ". m:" << M << ",n:" << N << ", k:" << K
|
||||
<< "is: \n";
|
||||
std::cout << "Running time :" << ave_time << "ms, Throughput" << gb_per_sec << "GB/s \n"
|
||||
<< std::flush;
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t K = arg_parser.get_int("k");
|
||||
|
||||
// The Matrix Multiplication goes with Matrix A (M, K), Matrix B (N, K) = Matrix C (M, N).
|
||||
using matrix_a_layout = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using matrix_b_layout = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using matrix_c_layout = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
|
||||
// host verify
|
||||
std::vector<int> a_dimensions =
|
||||
(std::is_same_v<matrix_a_layout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
? std::vector<int>{M, K}
|
||||
: std::vector<int>{K, M};
|
||||
std::vector<int> b_dimensions =
|
||||
(std::is_same_v<matrix_b_layout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
? std::vector<int>{N, K}
|
||||
: std::vector<int>{K, N};
|
||||
std::vector<int> c_dimensions =
|
||||
(std::is_same_v<matrix_c_layout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
? std::vector<int>{M, N}
|
||||
: std::vector<int>{N, M};
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_host(a_dimensions);
|
||||
ck_tile::HostTensor<BDataType> b_host(b_dimensions);
|
||||
|
||||
ck_tile::HostTensor<CDataType> c_host_ref(c_dimensions);
|
||||
ck_tile::HostTensor<CDataType> c_host_dev(c_dimensions);
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_host);
|
||||
|
||||
ck_tile::DeviceMem a_buf(a_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_buf(b_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_buf(c_host_dev.get_element_space_size_in_bytes());
|
||||
|
||||
a_buf.ToDevice(a_host.data());
|
||||
b_buf.ToDevice(b_host.data());
|
||||
|
||||
invoke_gemm<ck_tile::half_t, matrix_a_layout, matrix_b_layout, matrix_c_layout>(
|
||||
a_buf, b_buf, c_buf, arg_parser);
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(arg_parser.get_bool("v"))
|
||||
{
|
||||
// ToDo: Will Add the Element Op (bias) verification in the future.
|
||||
ck_tile::reference_gemm<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
matrix_a_layout,
|
||||
matrix_b_layout,
|
||||
matrix_c_layout>(a_host, b_host, c_host_ref);
|
||||
|
||||
c_buf.FromDevice(c_host_dev.data());
|
||||
|
||||
pass = ck_tile::check_err(c_host_dev, c_host_ref);
|
||||
|
||||
std::cout << "The veification result is:" << (pass ? "correct" : "fail") << std::flush;
|
||||
}
|
||||
|
||||
std::cout << std::endl << std::flush;
|
||||
|
||||
return !pass;
|
||||
}
|
||||
71
example/ck_tile/03_gemm/gemm_basic.hpp
Normal file
71
example/ck_tile/03_gemm/gemm_basic.hpp
Normal file
@@ -0,0 +1,71 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include <string>
|
||||
|
||||
template <typename DataType>
|
||||
struct GemmBasicTypeConfig;
|
||||
|
||||
template <>
|
||||
struct GemmBasicTypeConfig<ck_tile::half_t>
|
||||
{
|
||||
using ADataType = ck_tile::half_t;
|
||||
using BDataType = ck_tile::half_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t; // type convert
|
||||
// ToDo: Add more bias config to support different categories of GEMM.
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct DataTypeTraits;
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<float>
|
||||
{
|
||||
static constexpr const char* name = "fp32";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<double>
|
||||
{
|
||||
static constexpr const char* name = "fp64";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::half_t>
|
||||
{
|
||||
static constexpr const char* name = "fp16";
|
||||
};
|
||||
|
||||
using Types = GemmBasicTypeConfig<ck_tile::half_t>;
|
||||
|
||||
// Specific type aliases for easy access
|
||||
using ADataType = Types::ADataType;
|
||||
using BDataType = Types::BDataType;
|
||||
using AccDataType = Types::AccDataType;
|
||||
using CDataType = Types::CDataType;
|
||||
|
||||
struct gemm_basic_args
|
||||
{
|
||||
const void* p_a;
|
||||
const void* p_b;
|
||||
void* p_c;
|
||||
float epsilon;
|
||||
ck_tile::index_t kbatch;
|
||||
ck_tile::index_t M;
|
||||
ck_tile::index_t N;
|
||||
ck_tile::index_t K;
|
||||
ck_tile::index_t stride_A;
|
||||
ck_tile::index_t stride_B;
|
||||
ck_tile::index_t stride_C;
|
||||
};
|
||||
|
||||
// host API
|
||||
float gemm_calc(gemm_basic_args args, const ck_tile::stream_config& s);
|
||||
@@ -4,3 +4,4 @@ include_directories(AFTER
|
||||
|
||||
add_subdirectory(01_fmha)
|
||||
add_subdirectory(02_layernorm2d)
|
||||
add_subdirectory(03_gemm)
|
||||
|
||||
Reference in New Issue
Block a user