Merge branch 'develop' of https://github.com/ROCm/composable_kernel into update_cka8w8_uc

This commit is contained in:
aska-0096
2025-02-08 02:36:52 +00:00
32 changed files with 1163 additions and 217 deletions

View File

@@ -14,6 +14,7 @@ trigger:
branches:
include:
- develop
- amd-develop
paths:
exclude:
- .github

View File

@@ -13,7 +13,7 @@
#include "ck/utility/blkgemmpipe_scheduler.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
@@ -315,40 +315,27 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
std::cout << "Computing GEMM on host..." << std::endl;
}
Tensor<CDataType> c({M, N});
Tensor<float> a({M, K});
Tensor<float> b({K, N});
for(int m = 0; m < M; m++)
{
for(int k = 0; k < K; k++)
{
a(m, k) = ck::type_convert<float>(a_m_k(m, k)) *
ck::type_convert<float>(a_m_k_scale(m, k / Scale_Block_K));
}
}
for(int n = 0; n < N; n++)
{
for(int k = 0; k < K; k++)
{
b(k, n) = ck::type_convert<float>(b_k_n(k, n)) *
ck::type_convert<float>(b_k_n_scale(k / Scale_Block_K, n));
}
}
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<float,
float,
CShuffleDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>;
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMXGemm<ADataType,
BDataType,
CDataType,
AccDataType,
float,
PassThrough,
PassThrough,
PassThrough,
float,
float>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument =
ref_gemm.MakeArgument(a, b, c, PassThrough{}, PassThrough{}, PassThrough{});
auto ref_argument = ref_gemm.MakeArgument(a_m_k,
a_m_k_scale,
b_k_n,
b_k_n_scale,
c_m_n_host_result,
PassThrough{},
PassThrough{},
PassThrough{});
ref_invoker.Run(ref_argument);
@@ -366,8 +353,9 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
<< ((res_verified) ? " (PASSED!)" : " (FAILED!)") << std::endl;
}
res_verified = res_verified &&
ck::utils::check_err(c_m_n_device_result, c, "Error: Incorrect results!");
res_verified = res_verified && ck::utils::check_err(c_m_n_device_result,
c_m_n_host_result,
"Error: Incorrect results!");
if(config.verbosity > 0 && res_verified)
std::cout << "Done." << std::endl;

View File

@@ -12,7 +12,13 @@
#include "ck_tile/host.hpp"
#include "gemm_basic.hpp"
template <typename ALayout, typename BLayout, typename CLayout>
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
@@ -25,7 +31,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
// This part comes from the Codegen
constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 128;
constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t K_Tile = 64;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
@@ -33,7 +39,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8;
constexpr ck_tile::index_t K_Warp_Tile = 16;
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
@@ -99,12 +105,32 @@ int run_gemm_example(int argc, char* argv[])
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
if(a_layout == "R" && b_layout == "C")
{
return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
if(data_type == "fp16")
{
return run_gemm_example_with_layouts<ck_tile::half_t>(argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "bf16")
{
return run_gemm_example_with_layouts<ck_tile::bf16_t>(argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "fp8")
{
return run_gemm_example_with_layouts<ck_tile::fp8_t>(argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "bf8")
{
return run_gemm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data_type!");
}
}
else
{

View File

@@ -18,7 +18,7 @@
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE
#endif
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave
@@ -43,6 +43,33 @@ struct GemmBasicTypeConfig<ck_tile::half_t>
// ToDo: Add more bias config to support different categories of GEMM.
};
template <>
struct GemmBasicTypeConfig<ck_tile::bf16_t>
{
using ADataType = ck_tile::bf16_t;
using BDataType = ck_tile::bf16_t;
using AccDataType = float;
using CDataType = ck_tile::bf16_t;
};
template <>
struct GemmBasicTypeConfig<ck_tile::fp8_t>
{
using ADataType = ck_tile::fp8_t;
using BDataType = ck_tile::fp8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmBasicTypeConfig<ck_tile::bf8_t>
{
using ADataType = ck_tile::bf8_t;
using BDataType = ck_tile::bf8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <typename T>
struct DataTypeTraits;
@@ -64,13 +91,23 @@ struct DataTypeTraits<ck_tile::half_t>
static constexpr const char* name = "fp16";
};
using Types = GemmBasicTypeConfig<ck_tile::half_t>;
template <>
struct DataTypeTraits<ck_tile::bf16_t>
{
static constexpr const char* name = "bf16";
};
// Specific type aliases for easy access
using ADataType = Types::ADataType;
using BDataType = Types::BDataType;
using AccDataType = Types::AccDataType;
using CDataType = Types::CDataType;
template <>
struct DataTypeTraits<ck_tile::fp8_t>
{
static constexpr const char* name = "fp8";
};
template <>
struct DataTypeTraits<ck_tile::bf8_t>
{
static constexpr const char* name = "bf8";
};
auto create_args(int argc, char* argv[])
{

View File

@@ -9,6 +9,7 @@ static constexpr inline auto is_row_major(Layout layout_)
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
@@ -29,7 +30,8 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
template <typename ALayout, typename BLayout, typename CLayout>
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType,
typename ALayout, typename BLayout, typename CLayout>
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf,
ck_tile::DeviceMem& c_m_n_dev_buf,
@@ -55,7 +57,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
args.stride_B = stride_B;
args.stride_C = stride_C;
float ave_time = gemm_calc<ALayout, BLayout, CLayout>(
float ave_time = gemm_calc<ADataType, BDataType, AccDataType, CDataType,
ALayout, BLayout, CLayout>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::size_t flop = std::size_t(2) * M * N * K;
@@ -66,13 +69,19 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K
<< " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
<< " A_Layout =" << ALayout::name
<< " B_Layout =" << BLayout::name
<< " C_Layout =" << CLayout::name
<< " A Type = " << DataTypeTraits<ADataType>::name
<< " B Type = " << DataTypeTraits<BDataType>::name
<< " C Type = " << DataTypeTraits<CDataType>::name
<< " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl;
return ave_time;
}
template <typename ALayout, typename BLayout, typename CLayout>
template <typename PrecType, typename ALayout, typename BLayout, typename CLayout>
int run_gemm_example_with_layouts(int argc,
char* argv[],
const ALayout a_layout = ALayout{},
@@ -83,6 +92,11 @@ int run_gemm_example_with_layouts(int argc,
if(!result)
return -1;
using ADataType = typename GemmBasicTypeConfig<PrecType>::ADataType;
using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType;
using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t K = arg_parser.get_int("k");
@@ -119,7 +133,8 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
invoke_gemm<ALayout, BLayout, CLayout>(a_m_k_dev_buf,
invoke_gemm<ADataType, BDataType, AccDataType, CDataType,
ALayout, BLayout, CLayout>(a_m_k_dev_buf,
b_k_n_dev_buf,
c_m_n_dev_buf,
M,
@@ -145,7 +160,8 @@ int run_gemm_example_with_layouts(int argc,
a_m_k, b_k_n, c_m_n_host_ref);
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value);
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>
(K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_host_ref,
"Error: Incorrect results!",
@@ -202,7 +218,8 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
const float max_accumulated_value =
*std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value);
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>
(K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_gpu_ref,
"Error: Incorrect results!",

View File

@@ -2,7 +2,8 @@
EXE="$(find . -name tile_example_gemm_basic -type f | head -n 1)"
VALID=1
for b_matrix_layout in "R" "C"; do
for b_matrix_layout in "C"; do
for m in "64" "512" "1024" "2048"; do
for n in "512" "1024" "2048"; do
for k in "64" "512" "1024" "2048"; do

View File

@@ -0,0 +1,14 @@
#!/bin/sh
EXE="$(find . -name tile_example_gemm_basic -type f | head -n 1)"
VALID=1
for b_matrix_layout in "C"; do
for m in "64" "512" "1024" "2048"; do
for n in "512" "1024" "2048"; do
for k in "64" "512" "1024" "2048"; do
$EXE -prec=fp8 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
done
done
done
done

View File

@@ -2,10 +2,10 @@
EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)"
VALID=1
for b_matrix_layout in "R" "C"; do
for m in "64" "512" "1024" "2048"; do
for b_matrix_layout in "C"; do
for m in "512" "1024" "2048" "4096"; do
for n in "512" "1024" "2048"; do
for k in "64" "512" "1024" "2048"; do
for k in "512" "1024" "2048"; do
$EXE -prec=fp16 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
done
done

View File

@@ -0,0 +1,13 @@
#!/bin/sh
EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)"
VALID=1
for b_matrix_layout in "C"; do
for m in "512" "1024" "2048" "4096"; do
for n in "512" "1024" "2048"; do
for k in "512" "1024" "2048"; do
$EXE -prec=bf16 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
done
done
done
done

View File

@@ -0,0 +1,13 @@
#!/bin/sh
EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)"
VALID=1
for b_matrix_layout in "C"; do
for m in "512" "1024" "2048" "4096"; do
for n in "512" "1024" "2048"; do
for k in "512" "1024" "2048"; do
$EXE -prec=bf8 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
done
done
done
done

View File

@@ -0,0 +1,13 @@
#!/bin/sh
EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)"
VALID=1
for b_matrix_layout in "C"; do
for m in "512" "1024" "2048" "4096"; do
for n in "512" "1024" "2048"; do
for k in "512" "1024" "2048"; do
$EXE -prec=fp8 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
done
done
done
done

View File

@@ -7,22 +7,20 @@ export CK_REPEAT=1
COMMON_ARGS='-v=2 -warmup=0 -repeat=1'
run_fp16_tests() {
for batch in 1 2; do
for m in 128 1024; do
for n in 128 2048; do
for k in 32 64; do
run_tests() {
for m in 128 1024; do
for n in 128 2048; do
for k in 64 128; do
$EXE -b=$batch -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -e=1e-5 -prec=fp16 $COMMON_ARGS
if [ $? -eq 0 ]; then
echo "Success: Test with batch=$batch, m=$m, n=$n, k=$k executed successfully."
else
echo "Error: Test with batch=$batch, m=$m, n=$n, k=$k failed to execute properly."
# Optionally, exit or break if you need to halt further execution
# exit 1
fi
$EXE -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -prec=$1 $COMMON_ARGS
if [ $? -eq 0 ]; then
echo "Success: Test with m=$m, n=$n, k=$k executed successfully."
else
echo "Error: Test with m=$m, n=$n, k=$k failed to execute properly."
# Optionally, exit or break if you need to halt further execution
# exit 1
fi
done
done
done
done
@@ -30,6 +28,9 @@ run_fp16_tests() {
set -x
run_fp16_tests
run_tests "fp16"
run_tests "bf16"
run_tests "fp8"
run_tests "bf8"
set +x

View File

@@ -7,22 +7,20 @@ export CK_REPEAT=1
COMMON_ARGS='-v=2 -warmup=0 -repeat=1'
run_fp16_tests() {
for batch in 1 2; do
for m in 128 1024; do
for n in 128 2048; do
for k in 32 64; do
run_tests() {
for m in 512 1024; do
for n in 512 2048; do
for k in 512 1024; do
$EXE -b=$batch -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -e=1e-5 -prec=fp16 $COMMON_ARGS
if [ $? -eq 0 ]; then
echo "Success: Test with batch=$batch, m=$m, n=$n, k=$k executed successfully."
else
echo "Error: Test with batch=$batch, m=$m, n=$n, k=$k failed to execute properly."
# Optionally, exit or break if you need to halt further execution
# exit 1
fi
$EXE -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -prec=$1 $COMMON_ARGS
if [ $? -eq 0 ]; then
echo "Success: Test with batch=$batch, m=$m, n=$n, k=$k executed successfully."
else
echo "Error: Test with batch=$batch, m=$m, n=$n, k=$k failed to execute properly."
# Optionally, exit or break if you need to halt further execution
# exit 1
fi
done
done
done
done
@@ -30,6 +28,9 @@ run_fp16_tests() {
set -x
run_fp16_tests
run_tests "fp16"
run_tests "bf16"
run_tests "fp8"
run_tests "bf8"
set +x

View File

@@ -12,7 +12,13 @@
#include "ck_tile/host.hpp"
#include "gemm_basic.hpp"
template <typename ALayout, typename BLayout, typename CLayout>
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
@@ -33,7 +39,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
// Compute friendly for Intrawave scheduler
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t K_Tile = 64;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
@@ -243,24 +249,101 @@ int run_gemm_example(int argc, char* argv[])
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
if(a_layout == "R" && b_layout == "R")
{
return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
if(data_type == "fp16")
{
return run_gemm_example_with_layouts<ck_tile::half_t>(argc, argv, Row{}, Row{}, Row{});
}
else if(data_type == "bf16")
{
return run_gemm_example_with_layouts<ck_tile::bf16_t>(argc, argv, Row{}, Row{}, Row{});
}
else if(data_type == "fp8")
{
return run_gemm_example_with_layouts<ck_tile::fp8_t>(argc, argv, Row{}, Row{}, Row{});
}
else if(data_type == "bf8")
{
return run_gemm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Row{}, Row{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data_type!");
}
}
else if(a_layout == "R" && b_layout == "C")
{
return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
if(data_type == "fp16")
{
return run_gemm_example_with_layouts<ck_tile::half_t>(argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "bf16")
{
return run_gemm_example_with_layouts<ck_tile::bf16_t>(argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "fp8")
{
return run_gemm_example_with_layouts<ck_tile::fp8_t>(argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "bf8")
{
return run_gemm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data_type!");
}
}
else if(a_layout == "C" && b_layout == "C")
{
return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{});
if(data_type == "fp16")
{
return run_gemm_example_with_layouts<ck_tile::half_t>(argc, argv, Col{}, Col{}, Row{});
}
else if(data_type == "bf16")
{
return run_gemm_example_with_layouts<ck_tile::bf16_t>(argc, argv, Col{}, Col{}, Row{});
}
else if(data_type == "fp8")
{
return run_gemm_example_with_layouts<ck_tile::fp8_t>(argc, argv, Col{}, Col{}, Row{});
}
else if(data_type == "bf8")
{
return run_gemm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Col{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data_type!");
}
}
else if(a_layout == "C" && b_layout == "R")
{
return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
if(data_type == "fp16")
{
return run_gemm_example_with_layouts<ck_tile::half_t>(argc, argv, Col{}, Row{}, Row{});
}
else if(data_type == "bf16")
{
return run_gemm_example_with_layouts<ck_tile::bf16_t>(argc, argv, Col{}, Row{}, Row{});
}
else if(data_type == "fp8")
{
return run_gemm_example_with_layouts<ck_tile::fp8_t>(argc, argv, Col{}, Row{}, Row{});
}
else if(data_type == "bf8")
{
return run_gemm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Col{}, Row{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data_type!");
}
}
else
{

View File

@@ -138,6 +138,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
if(stream_config.log_level_ > 0)
{
arg.Print();
GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print();
}
if(!GridwiseGemm::CheckValidity(arg))
@@ -745,7 +746,9 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
<< "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
<< "BlkGemmPipelinePrefetchStages: "
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
<< "Kpack: "
<< GridwiseGemm::BlockwiseGemmPipe::AMmaKStride;
// clang-format on
return str.str();

View File

@@ -103,14 +103,22 @@ struct BlockwiseGemmXdlops_pipeline_hotloop_inst
KPerXDL);
printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: "
"%d, %d\n C MFMA inst: %d\n",
"%d, %d\n C MFMA inst: %d\n"
"A/B LDS read width: %d, %d, A/B LDS write width: %d, %d, A/B buffer load width: "
"%d/ %d\n",
A_Buffer_Load_Inst_Num,
B_Buffer_Load_Inst_Num,
A_LDS_Write_Inst_Num,
B_LDS_Write_Inst_Num,
A_LDS_Read_Inst_Num,
B_LDS_Read_Inst_Num,
C_MFMA_Inst_Num);
C_MFMA_Inst_Num,
A_LDS_Read_Width,
B_LDS_Read_Width,
ALDSWriteWidth,
BLDSWriteWidth,
ABufferLoadWidth,
BBufferLoadWidth);
}
};

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/numeric/vector_type.hpp"
@@ -8,16 +8,75 @@
namespace ck_tile {
CK_TILE_HOST_DEVICE bf16_t add_bf16_t(const bf16_t& a, const bf16_t& b)
template <typename T, typename ComputeType>
CK_TILE_HOST_DEVICE T add(const T& a, const T& b)
{
return type_convert<bf16_t>(type_convert<float>(a) + type_convert<float>(b));
return type_convert<T>(type_convert<ComputeType>(a) + type_convert<ComputeType>(b));
}
CK_TILE_HOST_DEVICE bf16x2_t add_bf16x2_t(const bf16x2_t& a, const bf16x2_t& b)
{
bf16x2_t rtn;
rtn[0] = add_bf16_t(a[0], b[0]);
rtn[1] = add_bf16_t(a[1], b[1]);
rtn[0] = add<bf16_t, float>(a[0], b[0]);
rtn[1] = add<bf16_t, float>(a[1], b[1]);
return rtn;
}
CK_TILE_HOST_DEVICE bf16x4_t add_bf16x4_t(const bf16x4_t& a, const bf16x4_t& b)
{
bf16x4_t rtn;
rtn[0] = add<bf16_t, float>(a[0], b[0]);
rtn[1] = add<bf16_t, float>(a[1], b[1]);
rtn[2] = add<bf16_t, float>(a[2], b[2]);
rtn[3] = add<bf16_t, float>(a[3], b[3]);
return rtn;
}
CK_TILE_HOST_DEVICE fp8x4_t add_fp8x4_t(const fp8x4_t& a, const fp8x4_t& b)
{
fp8x4_t rtn;
rtn[0] = add<fp8_t, float>(a[0], b[0]);
rtn[1] = add<fp8_t, float>(a[1], b[1]);
rtn[2] = add<fp8_t, float>(a[2], b[2]);
rtn[3] = add<fp8_t, float>(a[3], b[3]);
return rtn;
}
CK_TILE_HOST_DEVICE fp8x8_t add_fp8x8_t(const fp8x8_t& a, const fp8x8_t& b)
{
fp8x8_t rtn;
rtn[0] = add<fp8_t, float>(a[0], b[0]);
rtn[1] = add<fp8_t, float>(a[1], b[1]);
rtn[2] = add<fp8_t, float>(a[2], b[2]);
rtn[3] = add<fp8_t, float>(a[3], b[3]);
rtn[4] = add<fp8_t, float>(a[4], b[4]);
rtn[5] = add<fp8_t, float>(a[5], b[5]);
rtn[6] = add<fp8_t, float>(a[6], b[6]);
rtn[7] = add<fp8_t, float>(a[7], b[7]);
return rtn;
}
CK_TILE_HOST_DEVICE bf8x4_t add_bf8x4_t(const bf8x4_t& a, const bf8x4_t& b)
{
bf8x4_t rtn;
rtn[0] = add<bf8_t, float>(a[0], b[0]);
rtn[1] = add<bf8_t, float>(a[1], b[1]);
rtn[2] = add<bf8_t, float>(a[2], b[2]);
rtn[3] = add<bf8_t, float>(a[3], b[3]);
return rtn;
}
CK_TILE_HOST_DEVICE bf8x8_t add_bf8x8_t(const bf8x8_t& a, const bf8x8_t& b)
{
bf8x8_t rtn;
rtn[0] = add<bf8_t, float>(a[0], b[0]);
rtn[1] = add<bf8_t, float>(a[1], b[1]);
rtn[2] = add<bf8_t, float>(a[2], b[2]);
rtn[3] = add<bf8_t, float>(a[3], b[3]);
rtn[4] = add<bf8_t, float>(a[4], b[4]);
rtn[5] = add<bf8_t, float>(a[5], b[5]);
rtn[6] = add<bf8_t, float>(a[6], b[6]);
rtn[7] = add<bf8_t, float>(a[7], b[7]);
return rtn;
}
@@ -59,6 +118,192 @@ CK_TILE_DEVICE void atomic_add<bf16x2_t>(bf16x2_t* p_dst, const bf16x2_t& x)
} while(cur_v.u32 != old_v);
}
template <>
CK_TILE_DEVICE void atomic_add<bf16x4_t>(bf16x4_t* p_dst, bf16x4_t const& x)
{
// Union to treat the pointer as either bf16x4_t* or uint64_t*:
union U64BF164_ADDR
{
uint64_t* u64_a;
bf16x4_t* bf164_a;
};
// Union to treat the data as either bf16x4_t or 64-bit integer
union U64BF164
{
uint64_t u64;
bf16x4_t bf164;
};
U64BF164_ADDR addr;
addr.bf164_a = p_dst; // interpret p_dst as a 64-bit location
// First read (non-atomic) of the old value
U64BF164 cur_v;
cur_v.u64 = *addr.u64_a;
U64BF164 new_v_union;
uint64_t old_v, new_v;
do
{
// old 64 bits
old_v = cur_v.u64;
// Add elementwise in bf16
new_v_union.bf164 = add_bf16x4_t(cur_v.bf164, x);
new_v = new_v_union.u64;
// Attempt the 64-bit CAS
cur_v.u64 = atomicCAS(addr.u64_a, old_v, new_v);
} while(cur_v.u64 != old_v);
}
template <>
CK_TILE_DEVICE void atomic_add<fp8x4_t>(fp8x4_t* p_dst, const fp8x4_t& x)
{
union U32FP84_ADDR
{
uint32_t* u32_a;
fp8x4_t* fp84_a;
};
union U32FP84
{
uint32_t u32;
fp8x4_t fp84;
};
U32FP84_ADDR dword_addr;
U32FP84 cur_v;
U32FP84 new_;
uint32_t old_v, new_v;
dword_addr.fp84_a = p_dst;
cur_v.u32 = *dword_addr.u32_a;
do
{
old_v = cur_v.u32;
new_.fp84 = add_fp8x4_t(cur_v.fp84, x);
new_v = new_.u32;
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
} while(cur_v.u32 != old_v);
}
template <>
CK_TILE_DEVICE void atomic_add<bf8x4_t>(bf8x4_t* p_dst, const bf8x4_t& x)
{
union U32BF84_ADDR
{
uint32_t* u32_a;
bf8x4_t* bf84_a;
};
union U32BF84
{
uint32_t u32;
bf8x4_t bf84;
};
U32BF84_ADDR dword_addr;
U32BF84 cur_v;
U32BF84 new_;
uint32_t old_v, new_v;
dword_addr.bf84_a = p_dst;
cur_v.u32 = *dword_addr.u32_a;
do
{
old_v = cur_v.u32;
new_.bf84 = add_bf8x4_t(cur_v.bf84, x);
new_v = new_.u32;
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
} while(cur_v.u32 != old_v);
}
//
// Atomic add for fp8x8_t
//
template <>
CK_TILE_DEVICE void atomic_add<fp8x8_t>(fp8x8_t* p_dst, fp8x8_t const& x)
{
// Union for addressing 64 bits as either "fp8x8_t" or a 64-bit integer.
union U64FP88_ADDR
{
uint64_t* u64_a; // pointer to 64-bit integer
fp8x8_t* fp88_a; // pointer to fp8x8_t
};
union U64FP88
{
uint64_t u64;
fp8x8_t fp88;
};
U64FP88_ADDR dword_addr;
U64FP88 cur_v;
U64FP88 new_v_union;
uint64_t old_v, new_v;
// Point to the destination as both fp8x8_t* and uint64_t*.
dword_addr.fp88_a = p_dst;
// Initial read of 64 bits from memory
cur_v.u64 = *dword_addr.u64_a;
do
{
old_v = cur_v.u64;
// Add each fp8 element using your add_fp8x8_t(...) routine
new_v_union.fp88 = add_fp8x8_t(cur_v.fp88, x);
new_v = new_v_union.u64;
// Attempt 64-bit CAS
cur_v.u64 = atomicCAS(dword_addr.u64_a, old_v, new_v);
} while(cur_v.u64 != old_v);
}
//
// Atomic add for bf8x8_t
//
template <>
CK_TILE_DEVICE void atomic_add<bf8x8_t>(bf8x8_t* p_dst, bf8x8_t const& x)
{
union U64BF88_ADDR
{
uint64_t* u64_a;
bf8x8_t* bf88_a;
};
union U64BF88
{
uint64_t u64;
bf8x8_t bf88;
};
U64BF88_ADDR dword_addr;
U64BF88 cur_v;
U64BF88 new_v_union;
uint64_t old_v, new_v;
dword_addr.bf88_a = p_dst;
// Read the original 64 bits
cur_v.u64 = *dword_addr.u64_a;
do
{
old_v = cur_v.u64;
// Add each bf8 element using your add_bf8x8_t(...) routine
new_v_union.bf88 = add_bf8x8_t(cur_v.bf88, x);
new_v = new_v_union.u64;
// 64-bit CAS loop
cur_v.u64 = atomicCAS(dword_addr.u64_a, old_v, new_v);
} while(cur_v.u64 != old_v);
}
template <typename T, index_t N>
CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
{
@@ -66,8 +311,10 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
(std::is_same<T, uint32_t>::value && (N == 1)) ||
(std::is_same<T, float>::value && (N == 1 || N == 2)) ||
(std::is_same<T, double>::value && (N == 1 || N == 2)) ||
(std::is_same<T, bf16_t>::value && (N == 2 || N == 4)),
"wrong! not implemented");
(std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, fp8_t>::value && (N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, bf8_t>::value && (N == 4 || N == 8 || N == 16)),
"The granularity of the thread buffer is unsupported on the hardware!");
constexpr auto I0 = number<0>{};
constexpr auto I1 = number<1>{};
@@ -118,9 +365,45 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
}
else if constexpr(N == 4)
{
atomic_add(c_style_pointer_cast<bf16x2_t*>(p_dst), x.template get_as<bf16x2_t>()[I0]);
atomic_add(c_style_pointer_cast<bf16x2_t*>(p_dst) + 1,
x.template get_as<bf16x2_t>()[I1]);
atomic_add(c_style_pointer_cast<bf16x4_t*>(p_dst), x.template get_as<bf16x4_t>()[I0]);
}
else if constexpr(N == 8)
{
atomic_add(c_style_pointer_cast<bf16x4_t*>(p_dst), x.template get_as<bf16x4_t>()[I0]);
atomic_add(c_style_pointer_cast<bf16x4_t*>(p_dst) + 1,
x.template get_as<bf16x4_t>()[I1]);
}
}
else if constexpr(std::is_same<T, fp8_t>::value)
{
if constexpr(N == 4)
{
atomic_add(c_style_pointer_cast<fp8x4_t*>(p_dst), x.template get_as<fp8x4_t>()[I0]);
}
if constexpr(N == 8)
{
atomic_add(c_style_pointer_cast<fp8x8_t*>(p_dst), x.template get_as<fp8x8_t>()[I0]);
}
if constexpr(N == 16)
{
atomic_add(c_style_pointer_cast<fp8x8_t*>(p_dst), x.template get_as<fp8x8_t>()[I0]);
atomic_add(c_style_pointer_cast<fp8x8_t*>(p_dst) + 1, x.template get_as<fp8x8_t>()[I1]);
}
}
else if constexpr(std::is_same<T, bf8_t>::value)
{
if constexpr(N == 4)
{
atomic_add(c_style_pointer_cast<bf8x4_t*>(p_dst), x.template get_as<bf8x4_t>()[I0]);
}
if constexpr(N == 8)
{
atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst), x.template get_as<bf8x8_t>()[I0]);
}
if constexpr(N == 16)
{
atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst), x.template get_as<bf8x8_t>()[I0]);
atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst) + 1, x.template get_as<bf8x8_t>()[I1]);
}
}
}

View File

@@ -20,6 +20,7 @@
#include "ck_tile/host/reference/reference_batched_masking.hpp"
#include "ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp"
#include "ck_tile/host/reference/reference_batched_softmax.hpp"
#include "ck_tile/host/reference/reference_batched_transpose.hpp"
#include "ck_tile/host/reference/reference_elementwise.hpp"
#include "ck_tile/host/reference/reference_fused_moe.hpp"
#include "ck_tile/host/reference/reference_gemm.hpp"
@@ -34,4 +35,3 @@
#include "ck_tile/host/reference/reference_topk.hpp"
#include "ck_tile/host/stream_config.hpp"
#include "ck_tile/host/timer.hpp"
#include "ck_tile/host/reference/reference_batched_transpose.hpp"

View File

@@ -22,13 +22,14 @@ template <typename ComputeDataType, typename OutDataType, typename AccDataType =
double get_relative_threshold(const int number_of_accumulations = 1)
{
using F8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using F16 = ck_tile::half_t;
using BF16 = ck_tile::bf16_t;
using F32 = float;
using I8 = int8_t;
using I32 = int32_t;
static_assert(is_any_of<ComputeDataType, F8, F16, BF16, F32, I8, I32, int>::value,
static_assert(is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
"Warning: Unhandled ComputeDataType for setting up the relative threshold!");
double compute_error = 0;
@@ -41,7 +42,7 @@ double get_relative_threshold(const int number_of_accumulations = 1)
compute_error = std::pow(2, -numeric_traits<ComputeDataType>::mant) * 0.5;
}
static_assert(is_any_of<OutDataType, F8, F16, BF16, F32, I8, I32, int>::value,
static_assert(is_any_of<OutDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
"Warning: Unhandled OutDataType for setting up the relative threshold!");
double output_error = 0;
@@ -55,7 +56,7 @@ double get_relative_threshold(const int number_of_accumulations = 1)
}
double midway_error = std::max(compute_error, output_error);
static_assert(is_any_of<AccDataType, F8, F16, BF16, F32, I8, I32, int>::value,
static_assert(is_any_of<AccDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
"Warning: Unhandled AccDataType for setting up the relative threshold!");
double acc_error = 0;
@@ -74,13 +75,14 @@ template <typename ComputeDataType, typename OutDataType, typename AccDataType =
double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1)
{
using F8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using F16 = ck_tile::half_t;
using BF16 = ck_tile::bf16_t;
using F32 = float;
using I8 = int8_t;
using I32 = int32_t;
static_assert(is_any_of<ComputeDataType, F8, F16, BF16, F32, I8, I32, int>::value,
static_assert(is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
"Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
auto expo = std::log2(std::abs(max_possible_num));
@@ -94,7 +96,7 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
compute_error = std::pow(2, expo - numeric_traits<ComputeDataType>::mant) * 0.5;
}
static_assert(is_any_of<OutDataType, F8, F16, BF16, F32, I8, I32, int>::value,
static_assert(is_any_of<OutDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
"Warning: Unhandled OutDataType for setting up the absolute threshold!");
double output_error = 0;
@@ -108,7 +110,7 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
}
double midway_error = std::max(compute_error, output_error);
static_assert(is_any_of<AccDataType, F8, F16, BF16, F32, I8, I32, int>::value,
static_assert(is_any_of<AccDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
"Warning: Unhandled AccDataType for setting up the absolute threshold!");
double acc_error = 0;
@@ -501,7 +503,11 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
}
if(!res)
{
std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
const float error_percent =
static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
std::cerr << "max err: " << max_err;
std::cerr << ", number of errors: " << err_count;
std::cerr << ", " << error_percent << "% wrong values" << std::endl;
}
return res;
}

View File

@@ -80,13 +80,14 @@ __global__ void naive_gemm_kernel(ADataType* A,
int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
? col * strideB + k
: k * strideB + col;
acc += static_cast<AccDataType>(A[a_index]) * static_cast<AccDataType>(B[b_index]);
acc += ck_tile::type_convert<AccDataType>(A[a_index]) *
ck_tile::type_convert<AccDataType>(B[b_index]);
}
int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
? row * strideC + col
: col * strideC + row;
C[c_index] = acc;
C[c_index] = ck_tile::type_convert<CDataType>(acc);
}
}

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -77,6 +77,7 @@ struct CShuffleEpilogue
*
* @return The vector store size for C tensor.
*/
template <typename ODataType>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
{
constexpr index_t MaxVectorStoreSize = 16;
@@ -142,7 +143,7 @@ struct CShuffleEpilogue
TileDistributionEncodingPattern2D<kBlockSize,
kMPerIteration,
kNPerIteration,
GetVectorSizeC(),
GetVectorSizeC<ODataType>(),
tile_distribution_pattern::thread_raked>;
constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution();

View File

@@ -79,7 +79,10 @@ struct BlockUniversalGemmAsBsCr
// TODO: Should we have two policies? Interwave & Intrawave ??
static constexpr index_t InterWaveSchedulingMacClusters = 1;
static constexpr index_t KPack = WarpGemm::kKPerThread;
// should be at least equal to: WarpGemm::Impl::kABKPerLane
// and the question is how to assess upper limit or exact value?
// TODO: Should we introduce AK1/BK1 parameters ?
static constexpr index_t KPack = 8;
static constexpr index_t KPerThread = KIterPerWarp * KPack;
static constexpr index_t KRepeat = KPerThread / KPack;
};

View File

@@ -159,7 +159,7 @@ struct GemmKernel
CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs)
{
if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
if constexpr(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
is_any_of<CDataType, fp16_t, bf16_t>::value)
{
if(kargs.k_batch != 1)
@@ -240,7 +240,7 @@ struct GemmKernel
<< std::endl;
return false;
}
if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
if(kargs.N % EpiloguePipeline::template GetVectorSizeC<CDataType>() != 0)
{
std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl;
return false;
@@ -255,7 +255,7 @@ struct GemmKernel
<< std::endl;
return false;
}
if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
if(kargs.M % EpiloguePipeline::template GetVectorSizeC<CDataType>() != 0)
{
std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl;
return false;
@@ -321,7 +321,7 @@ struct GemmKernel
c_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_C, 1),
number<EpiloguePipeline::GetVectorSizeC()>{},
number<EpiloguePipeline::template GetVectorSizeC<CDataType>()>{},
number<1>{});
}
else
@@ -519,7 +519,7 @@ struct GemmKernel
{
// Do not compile in case where we have unsupported
// VectorSizeC & data type configuration.
if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
if constexpr(!(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
is_any_of<CDataType, fp16_t, bf16_t>::value))
{
RunGemm<memory_operation_enum::atomic_add>(

View File

@@ -3,6 +3,9 @@
#pragma once
#include <string>
#include <sstream>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
@@ -83,6 +86,56 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
return Policy::template GetSmemSize<Problem>();
}
CK_TILE_HOST static std::string Print()
{
constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM;
constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;
constexpr index_t WaveSize = 64;
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
// Below should be equal to AK1|BK1
constexpr index_t A_LDS_Read_Width = Policy::template GetSmemPackA<Problem>();
constexpr index_t B_LDS_Read_Width = Policy::template GetSmemPackB<Problem>();
constexpr index_t A_LDS_Write_Width = Policy::template GetSmemPackA<Problem>();
constexpr index_t B_LDS_Write_Width = Policy::template GetSmemPackB<Problem>();
constexpr index_t A_Buffer_Load_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
constexpr index_t B_Buffer_Load_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());
constexpr index_t A_LDS_Write_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width);
constexpr index_t B_LDS_Write_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width);
constexpr index_t A_LDS_Read_Inst_Num =
WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width);
constexpr index_t B_LDS_Read_Inst_Num =
WaveNumM * MPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
(BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
auto str = std::stringstream{};
str << "A/B vector size: " << GetVectorSizeA() << ", " << GetVectorSizeB() << "\n"
<< "A/B LDS read/write width: " << A_LDS_Read_Width << ", " << B_LDS_Read_Width << "\n"
<< "A/B buffer load inst: " << A_Buffer_Load_Inst_Num << ", " << B_Buffer_Load_Inst_Num
<< "\n"
<< "A/B LDS write inst: " << A_LDS_Write_Inst_Num << ", " << B_LDS_Write_Inst_Num
<< "\n"
<< "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
<< "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
<< "KPack: " << BlockGemm::Traits::KPack << "\n"
<< "PrefetchStages: " << PrefetchStages << "\n";
return str.str();
}
template <GemmPipelineScheduler Scheduler>
struct PipelineImpl : public PipelineImplBase
{
@@ -95,29 +148,35 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
{
constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(I0{});
constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(I1{});
constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2{});
constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM;
constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;
constexpr index_t WaveSize = 64;
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
constexpr index_t A_LDS_Read_Width = KPerXDL;
constexpr index_t B_LDS_Read_Width = KPerXDL;
// Below should be equal to AK1|BK1
constexpr index_t A_LDS_Read_Width = Policy::template GetSmemPackA<Problem>();
constexpr index_t B_LDS_Read_Width = Policy::template GetSmemPackB<Problem>();
constexpr index_t A_LDS_Write_Width = Policy::template GetSmemPackA<Problem>();
constexpr index_t B_LDS_Write_Width = Policy::template GetSmemPackB<Problem>();
constexpr index_t A_Buffer_Load_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
constexpr index_t B_Buffer_Load_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());
constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * KPerXDL);
constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * KPerXDL);
constexpr index_t A_LDS_Write_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width);
constexpr index_t B_LDS_Write_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width);
constexpr index_t A_LDS_Read_Inst_Num =
WaveNumN * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width);
constexpr index_t B_LDS_Read_Inst_Num =
WaveNumM * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
WaveNumM * MPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
(BlockSize / WaveSize) /

View File

@@ -90,7 +90,7 @@ struct BaseGemmPipelineAgBgCrMem
// LocalPreFillStages: 1
// LocalPreFetchStages: 0
// LocalSharedMemoryBuffer: 1
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy>
template <typename Problem, typename Policy = UniversalGemmPipelineAgBgCrPolicy>
struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
{
using Base = BaseGemmPipelineAgBgCrMem<Problem>;
@@ -165,11 +165,22 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
"A/B Dram block window should have the same data type as appropriate "
"([A|B]DataType) defined in Problem definition!");
static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}],
"A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
" or KPerBlock!");
constexpr bool is_a_col_major =
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
static_assert(is_a_col_major
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"A block window has incorrect lengths for defined ALayout!");
static_assert(is_b_row_major
? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"B block window has incorrect lengths for defined BLayout!");
// ------------------------------------------------------------------------------------
// Definitions of all needed tiles
@@ -213,25 +224,59 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
tuple_array<ABlockTile, PrefetchStages> a_block_tiles;
tuple_array<BBlockTile, PrefetchStages> b_block_tiles;
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
constexpr ADramTileWindowStep a_dram_tile_window_step =
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
// -----------------------------------------------------------------------------------------
// Gemm pipeline start
// prefetch
// global read 0
Base::GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window);
Base::GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window);
Base::GlobalPrefetch(
a_block_tiles.get(I0{}), a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(
b_block_tiles.get(I0{}), b_copy_dram_window, b_dram_tile_window_step);
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0
Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{}));
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(I0{}));
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
}
// Global prefetch [1, PrefetchStages]
static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window);
Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window);
Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
a_copy_dram_window,
a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
b_copy_dram_window,
b_dram_tile_window_step);
});
// main body
@@ -247,19 +292,45 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
block_sync_lds();
Base::LocalPrefill(
a_copy_lds_window,
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
a_element_func);
Base::LocalPrefill(
b_copy_lds_window,
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
b_element_func);
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(
a_shuffle_tmp,
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(
a_copy_lds_window,
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(
b_shuffle_tmp,
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(
b_copy_lds_window,
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
b_element_func);
}
Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
a_copy_dram_window);
a_copy_dram_window,
a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
b_copy_dram_window);
b_copy_dram_window,
b_dram_tile_window_step);
});
i += PrefetchStages;
@@ -275,12 +346,32 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
block_sync_lds();
Base::LocalPrefill(a_copy_lds_window,
a_block_tiles.get(number<prefetch_idx>{}),
a_element_func);
Base::LocalPrefill(b_copy_lds_window,
b_block_tiles.get(number<prefetch_idx>{}),
b_element_func);
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(number<prefetch_idx>{}));
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(a_copy_lds_window,
a_block_tiles.get(number<prefetch_idx>{}),
a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(number<prefetch_idx>{}));
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(b_copy_lds_window,
b_block_tiles.get(number<prefetch_idx>{}),
b_element_func);
}
});
block_sync_lds();
@@ -352,11 +443,22 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
"A/B Dram block window should have the same data type as appropriate "
"([A|B]DataType) defined in Problem definition!");
static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}],
"A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
" or KPerBlock!");
constexpr bool is_a_col_major =
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
static_assert(is_a_col_major
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"A block window has incorrect lengths for defined ALayout!");
static_assert(is_b_row_major
? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"B block window has incorrect lengths for defined BLayout!");
// ------------------------------------------------------------------------------------
// Definitions of all needed tiles
@@ -400,25 +502,58 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
tuple_array<ABlockTile, PrefetchStages> a_block_tiles;
tuple_array<BBlockTile, PrefetchStages> b_block_tiles;
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
constexpr ADramTileWindowStep a_dram_tile_window_step =
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
// -----------------------------------------------------------------------------------------
// Gemm pipeline start
// prefetch
// global read 0
Base::GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window);
Base::GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window);
Base::GlobalPrefetch(
a_block_tiles.get(I0{}), a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(
b_block_tiles.get(I0{}), b_copy_dram_window, b_dram_tile_window_step);
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0
Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{}));
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(I0{}));
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
}
// Global prefetch [1, PrefetchStages]
static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window);
Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window);
Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
a_copy_dram_window,
a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
b_copy_dram_window,
b_dram_tile_window_step);
});
// main body
@@ -432,19 +567,45 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
// no second block_sync_lds because it's interwave
Base::LocalPrefill(
a_copy_lds_window,
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
a_element_func);
Base::LocalPrefill(
b_copy_lds_window,
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
b_element_func);
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(
a_shuffle_tmp,
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(
a_copy_lds_window,
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(
b_shuffle_tmp,
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(
b_copy_lds_window,
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
b_element_func);
}
Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
a_copy_dram_window);
a_copy_dram_window,
a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
b_copy_dram_window);
b_copy_dram_window,
b_dram_tile_window_step);
});
i += PrefetchStages;
@@ -457,12 +618,32 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
// no second block_sync_lds because it's interwave
Base::LocalPrefill(a_copy_lds_window,
a_block_tiles.get(number<prefetch_idx>{}),
a_element_func);
Base::LocalPrefill(b_copy_lds_window,
b_block_tiles.get(number<prefetch_idx>{}),
b_element_func);
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(number<prefetch_idx>{}));
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(a_copy_lds_window,
a_block_tiles.get(number<prefetch_idx>{}),
a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(number<prefetch_idx>{}));
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(b_copy_lds_window,
b_block_tiles.get(number<prefetch_idx>{}),
b_element_func);
}
});
block_sync_lds();

View File

@@ -185,7 +185,6 @@ struct UniversalGemmPipelineAgBgCrPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
@@ -519,7 +518,7 @@ struct UniversalGemmPipelineAgBgCrPolicy
using ALayout = remove_cvref_t<typename Problem::ALayout>;
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();

View File

@@ -0,0 +1,178 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
template <typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename ScaleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA>
struct ReferenceMXGemm : public device::BaseOperator
{
// Argument
struct Argument : public device::BaseArgument
{
Argument(const Tensor<ADataType>& a_m_k,
const Tensor<ScaleDataType>& a_m_kblock_scales,
const Tensor<BDataType>& b_k_n,
const Tensor<ScaleDataType>& b_kblock_n_scales,
Tensor<CDataType>& c_m_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
: a_m_k_{a_m_k},
a_m_kblock_scales_{a_m_kblock_scales},
b_k_n_{b_k_n},
b_kblock_n_scales_{b_kblock_n_scales},
c_m_n_{c_m_n},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
}
const Tensor<ADataType>& a_m_k_;
const Tensor<ScaleDataType>& a_m_kblock_scales_;
const Tensor<BDataType>& b_k_n_;
const Tensor<ScaleDataType>& b_kblock_n_scales_;
Tensor<CDataType>& c_m_n_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
};
// Invoker
struct Invoker : public device::BaseInvoker
{
using Argument = ReferenceMXGemm::Argument;
float Run(const Argument& arg)
{
using GemmInstance = ck::tensor_operation::host::ReferenceGemm<ComputeTypeA,
ComputeTypeB,
CDataType,
AccDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
ComputeTypeA,
ComputeTypeB>;
Tensor<ComputeTypeA> a_m_k_scaled(arg.a_m_k_.mDesc);
Tensor<ComputeTypeB> b_k_n_scaled(arg.b_k_n_.mDesc);
const auto M = arg.a_m_k_.mDesc.GetLengths()[0];
const auto N = arg.b_k_n_.mDesc.GetLengths()[1];
const auto K = arg.a_m_k_.mDesc.GetLengths()[1];
const auto SCALE_BLOCK = K / arg.a_m_kblock_scales_.mDesc.GetLengths()[1];
for(size_t m = 0; m < M; m++)
{
for(size_t k = 0; k < K; k++)
{
a_m_k_scaled(m, k) =
type_convert<ComputeTypeA>(arg.a_m_k_(m, k)) *
type_convert<ComputeTypeA>(arg.a_m_kblock_scales_(m, k / SCALE_BLOCK));
}
}
for(size_t n = 0; n < N; n++)
{
for(size_t k = 0; k < K; k++)
{
b_k_n_scaled(k, n) =
type_convert<ComputeTypeB>(arg.b_k_n_(k, n)) *
type_convert<ComputeTypeB>(arg.b_kblock_n_scales_(k / SCALE_BLOCK, n));
}
}
auto ref_gemm = GemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_m_k_scaled,
b_k_n_scaled,
arg.c_m_n_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
ref_invoker.Run(ref_argument);
return 0;
}
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
static auto MakeArgument(const Tensor<ADataType>& a_m_k,
const Tensor<ScaleDataType>& a_m_kblock_scales,
const Tensor<BDataType>& b_k_n,
const Tensor<ScaleDataType>& b_kblock_n_scales,
Tensor<CDataType>& c_m_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{a_m_k,
a_m_kblock_scales,
b_k_n,
b_kblock_n_scales,
c_m_n,
a_element_op,
b_element_op,
c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceMXGemm"
<< std::endl;
// clang-format on
return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck

View File

@@ -72,7 +72,7 @@ using device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_instances = std::tuple
>;
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
using device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_instances = std::tuple <
using device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_instances = std::tuple<
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
@@ -86,16 +86,34 @@ using device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_instances = std::tuple
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 8, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 128, 8, 8, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 64, 128, 8, 8, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
// Memory friendly
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 16, 256, 8, 8, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 256, 8, 8, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 256, 8, 8, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 256, 8, 8, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 64, 256, 8, 8, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
// clang-format on
>;
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 32, 64, 4, 4, 32, 32, 2, 1, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 32, 64, 2, 2, 32, 32, 2, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 8, 8, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 64, 4, 4, 32, 32, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 64, 2, 2, 32, 32, 2, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 64, 8, 8, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 64, 8, 8, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 64, 8, 8, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 8, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 2, 2, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 64, 8, 8, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 64, 8, 8, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 64, 4, 4, 32, 32, 1, 2, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 64, 2, 2, 32, 32, 1, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
#endif // defined(CK_USE_AMD_MFMA_GFX950)
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
@@ -14,28 +14,26 @@ using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
ck_tile::GemmPipelineScheduler::Intrawave>;
// using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
// ck_tile::GemmPipelineScheduler::Interwave>;
// using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>;
using Comp = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Comp>;
// TODO: Enable Memory pipeline, when it would be updated for vector loads on non-K major tensors.
using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
ck_tile::GemmPipelineScheduler::Interwave>;
using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>;
using Comp = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Comp>;
// clang-format off
using KernelTypes = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler, PipelineType
// std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Comp>,
// std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
// std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Comp>,
// std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>,
// std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Comp>,
// std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
// std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Comp>
// std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem>
std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Comp>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem>
>;
// clang-format on