mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
Ck tile batched gemm example (#1615)
* [CK Tile] Batched GEMM Example * [CK Tile] Batched GEMM Example - minor refactor * [CK Tile] Batched GEMM Example - README update * [CK Tile] Batched Gemm Example - review changes - Added tensor data layours as input parameters - Changed structure of Host and Kernel args - Removed bug with invalid vector read on non-contiguous memory * [CK Tile] Batched Gemm Example - remove comment * [CK Tile] Batched Gemm Example - Add GTests part1 * [CK Tile] Batched Gemm Example - GTests part2 + review changes * [CK TILE] Batched GEMM post merge fixes * [CK Tile] Batched GEMM Example - fix pad views
This commit is contained in:
1
example/ck_tile/16_batched_gemm/CMakeLists.txt
Normal file
1
example/ck_tile/16_batched_gemm/CMakeLists.txt
Normal file
@@ -0,0 +1 @@
|
||||
add_executable(tile_example_batched_gemm EXCLUDE_FROM_ALL batched_gemm.cpp)
|
||||
37
example/ck_tile/16_batched_gemm/README.md
Normal file
37
example/ck_tile/16_batched_gemm/README.md
Normal file
@@ -0,0 +1,37 @@
|
||||
# Batched GEMM
|
||||
|
||||
This folder contains an example of batched GEMM using the ck_tile tile-programming implementation.
|
||||
|
||||
## build
|
||||
```
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch>
|
||||
make tile_example_batched_gemm -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_batched_gemm`
|
||||
|
||||
## example
|
||||
```
|
||||
args:
|
||||
-m m dimension (default:256)
|
||||
-n n dimension (default:128)
|
||||
-k k dimension (default:128)
|
||||
-a_layout A tensor data layout (default:R) (R for Row, C for Col)
|
||||
-b_layout B tensor data layout (default:R) (R for Row, C for Col)
|
||||
-c_layout C tensor data layout (default:R) (R for Row, C for Col)
|
||||
-stride_a Tensor A stride (default:0, which selects the packed stride for the layout)
|
||||
-stride_b Tensor B stride (default:0, which selects the packed stride for the layout)
|
||||
-stride_c Tensor C stride (default:0, which selects the packed stride for the layout)
|
||||
-batch_stride_a Batch A stride (default:32768)
|
||||
-batch_stride_b Batch B stride (default:16384)
|
||||
-batch_stride_c Batch C stride (default:32768)
|
||||
-batch_count Batch count (default:16)
|
||||
-v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2)
|
||||
-e Absolute error tolerance (default:1e-5)
|
||||
-prec data type. fp16/bf16/fp8/bf8 (default:fp16)
|
||||
-warmup number of iterations before benchmark the kernel (default:50)
|
||||
-repeat number of iterations to benchmark the kernel (default:100)
|
||||
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
|
||||
```
|
||||
103
example/ck_tile/16_batched_gemm/batched_gemm.cpp
Normal file
103
example/ck_tile/16_batched_gemm/batched_gemm.cpp
Normal file
@@ -0,0 +1,103 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "batched_gemm.hpp"
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
float batched_gemm(const batched_gemm_kargs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
|
||||
constexpr bool kPadM = false;
|
||||
constexpr bool kPadN = false;
|
||||
constexpr bool kPadK = false;
|
||||
constexpr bool kTilePermute = false;
|
||||
// The rank and permutation will also be generate out by the CodeGen part.
|
||||
constexpr ck_tile::index_t kOutputRank = 2;
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
// This part comes from the Codegen
|
||||
constexpr ck_tile::index_t M_Tile = 128;
|
||||
constexpr ck_tile::index_t N_Tile = 128;
|
||||
constexpr ck_tile::index_t K_Tile = 32;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 8;
|
||||
|
||||
// Whether doing the CShuffle (transpose before the global memory), depending on the output
|
||||
// layout.
|
||||
constexpr bool CShuffleEpilogue =
|
||||
std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>;
|
||||
|
||||
using CodegenGemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner = ck_tile::GemmTilePartitioner<CodegenGemmShape>;
|
||||
|
||||
using GemmEpilogue = std::conditional_t<
|
||||
CShuffleEpilogue,
|
||||
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<AccDataType,
|
||||
CDataType,
|
||||
kPadM,
|
||||
kPadN,
|
||||
kTilePermute,
|
||||
kOutputRank,
|
||||
1,
|
||||
0,
|
||||
TilePartitioner::kM,
|
||||
TilePartitioner::kN>>,
|
||||
ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>;
|
||||
|
||||
using CodegenGemmTraits =
|
||||
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
|
||||
|
||||
using CodegenPipelineProblem = ck_tile::
|
||||
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
|
||||
|
||||
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
|
||||
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
|
||||
using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
|
||||
|
||||
auto kargs = Kernel::MakeKargs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:"
|
||||
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
#include "run_batched_gemm_example.inc"
|
||||
|
||||
/// Program entry point.
///
/// run_batched_gemm_example() returns a positive value when verification passes,
/// 0 when it fails and -1 when the command line cannot be parsed. Negating that
/// status would turn -1 into exit code 0, so only a positive status is mapped to
/// a successful exit.
int main(int argc, char* argv[])
{
    const int status = run_batched_gemm_example(argc, argv);
    return status > 0 ? 0 : 1;
}
|
||||
63
example/ck_tile/16_batched_gemm/batched_gemm.hpp
Normal file
63
example/ck_tile/16_batched_gemm/batched_gemm.hpp
Normal file
@@ -0,0 +1,63 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
|
||||
|
||||
template <typename DataType>
|
||||
struct BatchedGemmTypeConfig;
|
||||
|
||||
template <>
|
||||
struct BatchedGemmTypeConfig<ck_tile::half_t>
|
||||
{
|
||||
using ADataType = ck_tile::half_t;
|
||||
using BDataType = ck_tile::half_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
using Types = BatchedGemmTypeConfig<ck_tile::half_t>;
|
||||
|
||||
// Specific type aliases for easy access
|
||||
using ADataType = Types::ADataType;
|
||||
using BDataType = Types::BDataType;
|
||||
using AccDataType = Types::AccDataType;
|
||||
using CDataType = Types::CDataType;
|
||||
|
||||
/// Host-side argument pack of the example; adds nothing to ck_tile::BatchedGemmHostArgs.
struct batched_gemm_kargs : ck_tile::BatchedGemmHostArgs
{
};
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "256", "m dimension")
|
||||
.insert("n", "128", "n dimension")
|
||||
.insert("k", "128", "k dimension")
|
||||
.insert("stride_a", "0", "Tensor A stride")
|
||||
.insert("stride_b", "0", "Tensor B stride")
|
||||
.insert("stride_c", "0", "Tensor C stride")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default")
|
||||
.insert("b_layout", "R", "B tensor data layout - Row by default")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default")
|
||||
.insert("batch_stride_a", "32768", "Batch A stride")
|
||||
.insert("batch_stride_b", "16384", "Batch B stride")
|
||||
.insert("batch_stride_c", "32768", "Batch C stride")
|
||||
.insert("batch_count", "16", "Batch count")
|
||||
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
|
||||
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
// host API
|
||||
float batched_gemm(batched_gemm_kargs args, const ck_tile::stream_config& s);
|
||||
253
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
Normal file
253
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
Normal file
@@ -0,0 +1,253 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ck_tile::DeviceMem& b_k_n_dev_buf,
|
||||
ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t stride_A,
|
||||
ck_tile::index_t stride_B,
|
||||
ck_tile::index_t stride_C,
|
||||
ck_tile::index_t batch_stride_A,
|
||||
ck_tile::index_t batch_stride_B,
|
||||
ck_tile::index_t batch_stride_C,
|
||||
ck_tile::index_t batch_count,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
{
|
||||
batched_gemm_kargs args;
|
||||
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
|
||||
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
|
||||
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
|
||||
args.M = M;
|
||||
args.N = N;
|
||||
args.K = K;
|
||||
args.stride_A = stride_A;
|
||||
args.stride_B = stride_B;
|
||||
args.stride_C = stride_C;
|
||||
args.batch_stride_A = batch_stride_A;
|
||||
args.batch_stride_B = batch_stride_B;
|
||||
args.batch_stride_C = batch_stride_C;
|
||||
args.batch_count = batch_count;
|
||||
|
||||
float ave_time = batched_gemm<ALayout, BLayout, CLayout>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
std::string op_name{"Batched Gemm"};
|
||||
std::size_t flop = std::size_t(2) * batch_count * M * N * K;
|
||||
std::size_t num_byte = sizeof(ADataType) * batch_count * M * K +
|
||||
sizeof(BDataType) * batch_count * N * K +
|
||||
sizeof(CDataType) * batch_count * M * N;
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Run " << op_name << "kernel with M =" << M << " N =" << N << " K =" << K
|
||||
<< " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
|
||||
<< " batch_stride_A =" << batch_stride_A << " batch_stride_B =" << batch_stride_B
|
||||
<< " batch_stride_C =" << batch_stride_C << " batch_count =" << batch_count << " : "
|
||||
<< ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< std::endl;
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
int run_batched_gemm_example_with_layouts(int argc,
|
||||
char* argv[],
|
||||
const ALayout a_layout = ALayout{},
|
||||
const BLayout b_layout = BLayout{},
|
||||
[[maybe_unused]] const CLayout c_layout = CLayout{})
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t K = arg_parser.get_int("k");
|
||||
|
||||
ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
|
||||
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
|
||||
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
|
||||
|
||||
ck_tile::index_t batch_stride_A = arg_parser.get_int("batch_stride_a");
|
||||
ck_tile::index_t batch_stride_B = arg_parser.get_int("batch_stride_b");
|
||||
ck_tile::index_t batch_stride_C = arg_parser.get_int("batch_stride_c");
|
||||
ck_tile::index_t batch_count = arg_parser.get_int("batch_count");
|
||||
|
||||
int n_warmup = arg_parser.get_int("warmup");
|
||||
int n_repeat = arg_parser.get_int("repeat");
|
||||
|
||||
using namespace ck_tile::literals;
|
||||
|
||||
auto f_host_tensor_descriptor = [](std::size_t batch_count_,
|
||||
std::size_t row,
|
||||
std::size_t col,
|
||||
std::size_t stride,
|
||||
std::size_t batch_stride,
|
||||
auto layout) {
|
||||
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return ck_tile::HostTensorDescriptor({batch_count_, row, col},
|
||||
{batch_stride, stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::HostTensorDescriptor({batch_count_, row, col},
|
||||
{batch_stride, 1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
auto f_get_default_stride = [](std::size_t row,
|
||||
std::size_t col,
|
||||
std::size_t stride,
|
||||
auto layout) {
|
||||
if(stride == 0)
|
||||
{
|
||||
// give a chance if stride is zero, return a default packed stride
|
||||
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return col;
|
||||
}
|
||||
else
|
||||
{
|
||||
return row;
|
||||
}
|
||||
}
|
||||
else
|
||||
return stride;
|
||||
};
|
||||
|
||||
stride_A = f_get_default_stride(M, K, stride_A, a_layout);
|
||||
stride_B = f_get_default_stride(K, N, stride_B, b_layout);
|
||||
stride_C = f_get_default_stride(M, N, stride_C, c_layout);
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m_k(
|
||||
f_host_tensor_descriptor(batch_count, M, K, stride_A, batch_stride_A, a_layout));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(
|
||||
f_host_tensor_descriptor(batch_count, K, N, stride_B, batch_stride_B, b_layout));
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
|
||||
f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, c_layout));
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
|
||||
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
invoke_batched_gemm<ALayout, BLayout, CLayout>(a_m_k_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
c_m_n_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
batch_stride_A,
|
||||
batch_stride_B,
|
||||
batch_stride_C,
|
||||
batch_count,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
bool pass = true;
|
||||
|
||||
if(arg_parser.get_int("v") == 1)
|
||||
{
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
|
||||
f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, CLayout{}));
|
||||
c_m_n_host_ref.SetZero();
|
||||
|
||||
const auto b_n_k = b_k_n.transpose({0, 2, 1});
|
||||
|
||||
ck_tile::reference_batched_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
a_m_k, b_n_k, c_m_n_host_ref);
|
||||
|
||||
pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref);
|
||||
|
||||
std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
else if(arg_parser.get_int("v") == 2)
|
||||
{
|
||||
ck_tile::HostTensor<CDataType> c_m_n_gpu_ref(
|
||||
f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, CLayout{}));
|
||||
ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes());
|
||||
c_m_n_gpu_ref.SetZero();
|
||||
c_m_n_gpu_buf_ref.SetZero();
|
||||
|
||||
ck_tile::reference_batched_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(a_m_k_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
c_m_n_gpu_buf_ref,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
batch_stride_A,
|
||||
batch_stride_B,
|
||||
batch_stride_C,
|
||||
batch_count);
|
||||
|
||||
c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
|
||||
pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref);
|
||||
|
||||
std::cout << "The GPU verification result is: " << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
int run_batched_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
if(a_layout == "R" && b_layout == "R")
|
||||
{
|
||||
return run_batched_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_batched_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
// TODO: Fixme: with latest changes to GemmPipelineAGmemBGmemCRegV1DefaultPolicy below do not
|
||||
// work else if(a_layout == "C" && b_layout == "C")
|
||||
// {
|
||||
// return run_batched_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{});
|
||||
// }
|
||||
// else if(a_layout == "C" && b_layout == "R")
|
||||
// {
|
||||
// return run_batched_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
|
||||
// }
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
|
||||
}
|
||||
}
|
||||
@@ -15,4 +15,4 @@ add_subdirectory(12_smoothquant)
|
||||
add_subdirectory(13_moe_sorting)
|
||||
add_subdirectory(14_moe_smoothquant)
|
||||
add_subdirectory(15_fused_moe)
|
||||
|
||||
add_subdirectory(16_batched_gemm)
|
||||
|
||||
Reference in New Issue
Block a user