mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
[rocm-libraries] ROCm/rocm-libraries#4302 (commit e62bd8a)
[CK_TILE] add tf32 support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Proposed changes TF32 is added in CK on gfx942 and gfx950. This PR is to initiate tf32 in CK_TILE on gfx942 and gfx950. ## Checklist Please put an into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask. - [ ] I have added tests relevant to the introduced functionality, and the unit tests are passing locally - [ ] I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more than 30 seconds to run. - [ ] I have added inline documentation which enables the maintainers with understanding the motivation - [ ] I have removed the stale documentation which is no longer relevant after this pull request - [ ] (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request - [x] I have run on all changed files - [ ] Any dependent changes have been merged ## Discussion
This commit is contained in:
committed by
assistant-librarian[bot]
parent
652d3456ca
commit
d460ab35b6
@@ -41,6 +41,17 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
|
||||
return run_gemm_example_prec_type<GemmConfig, Invoker, ck_tile::bf16_t>(
|
||||
a_layout, b_layout, arg_parser);
|
||||
}
|
||||
#ifdef CK_GFX950_SUPPORT
|
||||
else if(data_type == "tf32")
|
||||
{
|
||||
// Pass tf32_t as A/B types - epilogue auto-detects and maps to float for data operations
|
||||
return run_gemm_example_prec_type<GemmConfig,
|
||||
Invoker,
|
||||
ck_tile::tf32_t,
|
||||
ck_tile::tf32_t,
|
||||
float>(a_layout, b_layout, arg_parser);
|
||||
}
|
||||
#endif
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig,
|
||||
|
||||
@@ -6,8 +6,8 @@
|
||||
struct BasicInvoker
|
||||
{
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
@@ -19,14 +19,30 @@ struct BasicInvoker
|
||||
typename CDEElementWise>
|
||||
static float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
// ADataTypeCompute: compute type (tf32_t for TF32 mode, used for warp gemm selection)
|
||||
// ADataTypeBuf: buffer/storage type (fp32 when tf32)
|
||||
using ADataTypeCompute = ADataType_;
|
||||
using BDataTypeCompute = BDataType_;
|
||||
using ADataTypeBuf = ck_tile::if_select_t<ADataType_, ck_tile::tf32_t, float, ADataType_>;
|
||||
using BDataTypeBuf = ck_tile::if_select_t<BDataType_, ck_tile::tf32_t, float, BDataType_>;
|
||||
|
||||
if constexpr(std::is_same_v<ADataTypeCompute, ck_tile::tf32_t>)
|
||||
{
|
||||
static_assert(std::is_same_v<ADataTypeCompute, BDataTypeCompute>,
|
||||
"ADataTypeCompute and BDataTypeCompute must be the same");
|
||||
}
|
||||
|
||||
if constexpr(Persistent)
|
||||
{
|
||||
std::cout << "WARNING: Ignoring persistent kernel option for basic gemm." << std::endl;
|
||||
}
|
||||
|
||||
constexpr bool is_fp32_input = std::is_same_v<ADataTypeBuf, float>;
|
||||
constexpr bool is_tf32_compute = std::is_same_v<ADataTypeCompute, ck_tile::tf32_t>;
|
||||
|
||||
// This part comes from the Codegen
|
||||
constexpr ck_tile::index_t M_Tile = 256;
|
||||
constexpr ck_tile::index_t N_Tile = 256;
|
||||
constexpr ck_tile::index_t M_Tile = is_fp32_input ? 128 : 256;
|
||||
constexpr ck_tile::index_t N_Tile = is_fp32_input ? 128 : 256;
|
||||
constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
#if CK_TILE_USE_WMMA
|
||||
@@ -38,12 +54,14 @@ struct BasicInvoker
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
#else
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
// gfx950: fp32 uses 16x16x16 tile (native MFMA)
|
||||
// tf32 uses 32x32x16 tile (3x bf16 32x32x16 MFMA emulation)
|
||||
constexpr ck_tile::index_t M_Warp = (is_fp32_input && !is_tf32_compute) ? 4 : 2;
|
||||
constexpr ck_tile::index_t N_Warp = (is_fp32_input && !is_tf32_compute) ? 4 : 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t M_Warp_Tile = (is_fp32_input && !is_tf32_compute) ? 16 : 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = (is_fp32_input && !is_tf32_compute) ? 16 : 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
#endif
|
||||
|
||||
@@ -61,17 +79,21 @@ struct BasicInvoker
|
||||
BLayout,
|
||||
CLayout>;
|
||||
|
||||
using CodegenPipelineProblem = ck_tile::GemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenGemmShape,
|
||||
CodegenGemmTraits>;
|
||||
using CodegenPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataTypeBuf,
|
||||
BDataTypeBuf,
|
||||
AccDataType,
|
||||
CodegenGemmShape,
|
||||
CodegenGemmTraits,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ADataTypeCompute>;
|
||||
|
||||
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::CShuffleEpilogueProblem<ADataTypeCompute,
|
||||
BDataTypeCompute,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
@@ -112,7 +134,7 @@ struct BasicInvoker
|
||||
}
|
||||
|
||||
// Declare rotating_mem_ptr here so it stays in scope until it is needed
|
||||
std::unique_ptr<ck_tile::RotatingMemWrapper<ADataType, BDataType>> rotating_mem_ptr;
|
||||
std::unique_ptr<ck_tile::RotatingMemWrapper<ADataTypeBuf, BDataTypeBuf>> rotating_mem_ptr;
|
||||
std::function<void()> preprocess;
|
||||
|
||||
auto clear_gemm_output = [&]() {
|
||||
@@ -125,16 +147,21 @@ struct BasicInvoker
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
ck_tile::HostTensor<ADataTypeBuf> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
ck_tile::HostTensor<BDataTypeBuf> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
|
||||
|
||||
rotating_mem_ptr = std::make_unique<ck_tile::RotatingMemWrapper<ADataType, BDataType>>(
|
||||
kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem_ptr =
|
||||
std::make_unique<ck_tile::RotatingMemWrapper<ADataTypeBuf, BDataTypeBuf>>(
|
||||
kargs.as_ptr[0],
|
||||
kargs.bs_ptr[0],
|
||||
s.rotating_count_,
|
||||
size_a_buffer,
|
||||
size_b_buffer);
|
||||
rotating_mem_ptr->Print();
|
||||
|
||||
preprocess = [&]() {
|
||||
|
||||
@@ -35,6 +35,10 @@ struct GemmConfigBase
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
};
|
||||
|
||||
// Type trait for tf32 storage type (tf32 uses float for memory layout calculations)
|
||||
template <typename T>
|
||||
using prec_storage_type = ck_tile::if_select_t<T, ck_tile::tf32_t, float, T>;
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigMemoryInterwave : public GemmConfigBase
|
||||
{
|
||||
@@ -81,7 +85,7 @@ struct GemmConfigComputeV3 : public GemmConfigBase
|
||||
// Compute V3 only support Intrawave scheduler
|
||||
static constexpr ck_tile::index_t M_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
|
||||
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(prec_storage_type<PrecType>);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
@@ -121,7 +125,7 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(prec_storage_type<PrecType>);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
@@ -293,7 +297,7 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(prec_storage_type<PrecType>);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
@@ -302,7 +306,7 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
|
||||
ck_tile::get_k_warp_tile<prec_storage_type<PrecType>, M_Warp_Tile, true>();
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
@@ -324,6 +328,15 @@ struct GemmConfigPreshufflePrefill_Wmma : public GemmConfigPreshufflePrefill<Pre
|
||||
template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>
|
||||
struct GemmTypeConfig;
|
||||
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::tf32_t, ck_tile::tf32_t, float>
|
||||
{
|
||||
using ADataType = float;
|
||||
using BDataType = float;
|
||||
using AccDataType = float;
|
||||
using CDataType = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::half_t>
|
||||
{
|
||||
@@ -486,7 +499,7 @@ inline auto create_args()
|
||||
.insert("stride_b", "0", "Tensor B stride")
|
||||
.insert("stride_c", "0", "Tensor C stride")
|
||||
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
|
||||
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8/pk_int4_t")
|
||||
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8/pk_int4_t/tf32 (tf32 only on gfx950)")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
|
||||
@@ -30,6 +30,7 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
@@ -205,11 +206,13 @@ std::tuple<ck_tile::index_t, ck_tile::index_t, ck_tile::index_t> inline parse_ge
|
||||
return std::make_tuple(M, N, K);
|
||||
}
|
||||
|
||||
// ADataType_ and BDataType_ are original types (e.g., tf32_t for TF32 mode)
|
||||
// They are passed through invoke_gemm to invoker for tf32 auto-detection
|
||||
template <typename GemmConfig,
|
||||
typename Invoker,
|
||||
typename ADataType,
|
||||
typename BDataType = ADataType,
|
||||
typename CDataType = ADataType,
|
||||
typename ADataType_,
|
||||
typename BDataType_ = ADataType_,
|
||||
typename CDataType_ = ADataType_,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
@@ -218,7 +221,18 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
|
||||
const BLayout b_layout = BLayout{},
|
||||
[[maybe_unused]] const CLayout c_layout = CLayout{})
|
||||
{
|
||||
using AccDataType = typename GemmTypeConfig<ADataType, BDataType, CDataType>::AccDataType;
|
||||
// ADataTypeCompute: compute type (tf32_t for TF32 mode, used for warp gemm selection)
|
||||
// ADataTypeBuf: buffer/storage type (fp32 when tf32, from TypeConfig)
|
||||
using ADataTypeCompute = ADataType_;
|
||||
using BDataTypeCompute = BDataType_;
|
||||
|
||||
// Use GemmTypeConfig to get actual data types for tensor operations
|
||||
// This handles tf32 -> float mapping for host tensors and device buffers
|
||||
using TypeConfig = GemmTypeConfig<ADataType_, BDataType_, CDataType_>;
|
||||
using ADataTypeBuf = typename TypeConfig::ADataType;
|
||||
using BDataTypeBuf = typename TypeConfig::BDataType;
|
||||
using CDataType = typename TypeConfig::CDataType;
|
||||
using AccDataType = typename TypeConfig::AccDataType;
|
||||
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
@@ -242,27 +256,27 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
|
||||
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
|
||||
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m_k(
|
||||
ck_tile::HostTensor<ADataTypeBuf> a_m_k(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(
|
||||
ck_tile::HostTensor<BDataTypeBuf> b_k_n(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-2.f, 2.f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.f, 2.f}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<ADataTypeBuf>{-2.f, 2.f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataTypeBuf>{-2.f, 2.f}(b_k_n);
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k);
|
||||
ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n);
|
||||
ck_tile::FillMonotonicSeq<ADataTypeBuf>{}(a_m_k);
|
||||
ck_tile::FillMonotonicSeq<BDataTypeBuf>{}(b_k_n);
|
||||
}
|
||||
else if(init_method == 2)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<ADataTypeBuf>{1.f, 1.f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataTypeBuf>{1.f, 1.f}(b_k_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -274,7 +288,7 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
|
||||
{
|
||||
if constexpr(GemmConfig::UseStructuredSparsity)
|
||||
{
|
||||
ck_tile::AdjustToStructuredSparsity<ADataType>{}(a_m_k);
|
||||
ck_tile::AdjustToStructuredSparsity<ADataTypeBuf>{}(a_m_k);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -286,7 +300,7 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
|
||||
|
||||
if constexpr(preshuffle)
|
||||
{
|
||||
ck_tile::HostTensor<BDataType> b_shuffle_host = [&]() {
|
||||
ck_tile::HostTensor<BDataTypeBuf> b_shuffle_host = [&]() {
|
||||
if constexpr(GemmConfig::TiledMMAPermuteN)
|
||||
{
|
||||
std::cout << "Run with PermuteN" << std::endl;
|
||||
@@ -299,7 +313,7 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
|
||||
}
|
||||
}();
|
||||
// shuffled buffer B for device implementation
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
if constexpr(std::is_same_v<BDataTypeBuf, ck_tile::pk_int4_t>)
|
||||
{
|
||||
ck_tile::permute_vectors_i4x4_b(b_shuffle_host);
|
||||
}
|
||||
@@ -307,16 +321,16 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
if constexpr(std::is_same_v<BDataTypeBuf, ck_tile::pk_int4_t>)
|
||||
{
|
||||
// Permute vector pk_i4x4 data for device implementation
|
||||
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
|
||||
ck_tile::HostTensor<BDataTypeBuf> b_k_n_dev = b_k_n;
|
||||
if constexpr(GemmConfig::PermuteB)
|
||||
{
|
||||
permute_tensor_b<GemmConfig,
|
||||
decltype(b_k_n_dev),
|
||||
ADataType,
|
||||
BDataType,
|
||||
ADataTypeBuf,
|
||||
BDataTypeBuf,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
@@ -343,8 +357,8 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
|
||||
|
||||
float ave_time = invoke_gemm<GemmConfig,
|
||||
Invoker,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ADataTypeCompute,
|
||||
BDataTypeCompute,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
@@ -371,8 +385,8 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_byte =
|
||||
sizeof(ADataType) * M * K / ck_tile::numeric_traits<ADataType>::PackedSize +
|
||||
sizeof(BDataType) * N * K / ck_tile::numeric_traits<BDataType>::PackedSize +
|
||||
sizeof(ADataTypeBuf) * M * K / ck_tile::numeric_traits<ADataTypeBuf>::PackedSize +
|
||||
sizeof(BDataTypeBuf) * N * K / ck_tile::numeric_traits<BDataTypeBuf>::PackedSize +
|
||||
sizeof(CDataType) * M * N;
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
@@ -381,8 +395,8 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
|
||||
<< " StrideA=" << stride_A << " StrideB=" << stride_B << " StrideC=" << stride_C
|
||||
<< " A_Layout=" << ALayout::name << " B_Layout =" << BLayout::name
|
||||
<< " C_Layout=" << CLayout::name
|
||||
<< " A_Type=" << ck_tile::DataTypeTraits<ADataType>::name
|
||||
<< " B_Type=" << ck_tile::DataTypeTraits<BDataType>::name
|
||||
<< " A_Type=" << ck_tile::DataTypeTraits<ADataTypeBuf>::name
|
||||
<< " B_Type=" << ck_tile::DataTypeTraits<BDataTypeBuf>::name
|
||||
<< " C_Type=" << ck_tile::DataTypeTraits<CDataType>::name
|
||||
<< " StructuredSparsity=" << (GemmConfig::UseStructuredSparsity ? "on" : "off")
|
||||
<< " Persistent=" << (persistent ? "on" : "off") << " : " << ave_time << " ms, "
|
||||
@@ -397,17 +411,18 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
|
||||
|
||||
if(arg_parser.get_int("v") == 1)
|
||||
{
|
||||
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
ck_tile::reference_gemm<ADataTypeCompute, BDataTypeCompute, AccDataType, CDataType>(
|
||||
a_m_k, b_k_n, c_m_n_ref);
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
const auto rtol_atol =
|
||||
calculate_rtol_atol<ADataTypeCompute, BDataTypeCompute, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
pass = do_verify(c_m_n_dev_result, c_m_n_ref, rtol_atol, "CPU");
|
||||
}
|
||||
else if(arg_parser.get_int("v") == 2)
|
||||
{
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
if constexpr(std::is_same_v<BDataTypeBuf, ck_tile::pk_int4_t>)
|
||||
{
|
||||
// Restore input for B for gpu reference
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
@@ -421,12 +436,12 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
|
||||
ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_ref.get_element_space_size_in_bytes());
|
||||
c_m_n_gpu_buf_ref.SetZero();
|
||||
|
||||
ADataType* d_A = static_cast<ADataType*>(a_m_k_dev_buf.GetDeviceBuffer());
|
||||
BDataType* d_B = static_cast<BDataType*>(b_k_n_dev_buf.GetDeviceBuffer());
|
||||
CDataType* d_C = static_cast<CDataType*>(c_m_n_gpu_buf_ref.GetDeviceBuffer());
|
||||
ADataTypeBuf* d_A = static_cast<ADataTypeBuf*>(a_m_k_dev_buf.GetDeviceBuffer());
|
||||
BDataTypeBuf* d_B = static_cast<BDataTypeBuf*>(b_k_n_dev_buf.GetDeviceBuffer());
|
||||
CDataType* d_C = static_cast<CDataType*>(c_m_n_gpu_buf_ref.GetDeviceBuffer());
|
||||
|
||||
ck_tile::reference_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
ck_tile::reference_gemm_gpu<ADataTypeCompute,
|
||||
BDataTypeCompute,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
@@ -437,8 +452,9 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
|
||||
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
const auto rtol_atol =
|
||||
calculate_rtol_atol<ADataTypeCompute, BDataTypeCompute, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
pass = do_verify(c_m_n_dev_result, c_m_n_ref, rtol_atol, "GPU");
|
||||
}
|
||||
|
||||
@@ -447,8 +463,8 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
|
||||
dump_gemm_json_results<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ADataTypeBuf,
|
||||
BDataTypeBuf,
|
||||
CDataType,
|
||||
GemmConfig,
|
||||
ck_tile::DataTypeTraits>(arg_parser.get_str("jsonfile"),
|
||||
|
||||
Reference in New Issue
Block a user