mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
Merge commit '507d81c3af51b81f15b946a2a4bef7f594620292' into develop
This commit is contained in:
@@ -425,6 +425,11 @@ struct DeviceGemm_Xdl_CShuffleV3_BPreshuffle : public DeviceGemmV2BPreshuffle<AL
|
||||
return false;
|
||||
}
|
||||
|
||||
if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(get_warp_size() == 64)
|
||||
{
|
||||
if constexpr(NXdlPerWave64 > 0)
|
||||
|
||||
@@ -40,14 +40,22 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
{
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
// Full K needed for matrix B
|
||||
const index_t Kt = karg.K;
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
|
||||
|
||||
const index_t num_k_per_block = GridwiseGemm::CalculateBK0Shuffled(karg.K);
|
||||
const index_t k_id = blockIdx.z * num_k_per_block;
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
|
||||
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_b_grid,
|
||||
karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
|
||||
p_shared,
|
||||
karg);
|
||||
karg,
|
||||
k_id,
|
||||
Kt);
|
||||
}
|
||||
#else
|
||||
ignore = karg;
|
||||
@@ -74,15 +82,23 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
// Full K needed for matrix B
|
||||
const index_t Kt = karg.K;
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
|
||||
|
||||
const index_t num_k_per_block = GridwiseGemm::CalculateBK0Shuffled(karg.K);
|
||||
const index_t k_id = blockIdx.z * num_k_per_block;
|
||||
|
||||
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
|
||||
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_b_grid,
|
||||
karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
|
||||
p_shared_0,
|
||||
p_shared_1,
|
||||
karg);
|
||||
karg,
|
||||
k_id,
|
||||
Kt);
|
||||
}
|
||||
#else
|
||||
ignore = karg;
|
||||
@@ -658,25 +674,6 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
|
||||
a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA;
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
{
|
||||
b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB;
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
if constexpr(!PermuteB)
|
||||
{
|
||||
// b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize;
|
||||
|
||||
b_k_split_offset = blockIdx.z * karg.KRead * NLane / BPackedSize;
|
||||
}
|
||||
else
|
||||
{
|
||||
const int k0_offset = karg.KRead * karg.N;
|
||||
b_k_split_offset = blockIdx.z * k0_offset / BPackedSize;
|
||||
}
|
||||
}
|
||||
|
||||
if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
|
||||
{
|
||||
karg.K = karg.KRead;
|
||||
@@ -697,7 +694,6 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
|
||||
}
|
||||
|
||||
index_t a_k_split_offset;
|
||||
index_t b_k_split_offset;
|
||||
index_t c_reduce_offset;
|
||||
};
|
||||
|
||||
@@ -900,6 +896,11 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
|
||||
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
|
||||
"Invalid tuning param!");
|
||||
|
||||
if constexpr(NXdlPerWave % CShuffleNXdlPerWavePerShuffle != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
|
||||
@@ -1134,7 +1135,8 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
|
||||
const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BPreshuffled& b_grid_desc_bpreshuffled,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock)
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const index_t k_id)
|
||||
{
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
@@ -1226,7 +1228,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
|
||||
true>(b_grid_desc_bpreshuffled,
|
||||
make_multi_index(n_block_data_idx_on_grid,
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
k_id,
|
||||
KPack * (get_thread_local_1d_id() % WarpSize)));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
@@ -1465,10 +1467,12 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
|
||||
const BDataType* p_b_grid,
|
||||
CDataType* p_c_grid,
|
||||
void* p_shared,
|
||||
const Problem& problem)
|
||||
const Problem& problem,
|
||||
const index_t k_id,
|
||||
const index_t Kt)
|
||||
{
|
||||
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
|
||||
index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
|
||||
index_t BK0Shuffled = CalculateBK0Shuffled(Kt);
|
||||
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
|
||||
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
|
||||
const auto b_grid_desc_bpreshuffled =
|
||||
@@ -1491,7 +1495,8 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
|
||||
problem,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bpreshuffled,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock);
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_id);
|
||||
}
|
||||
|
||||
template <typename AGridDesc_AK0_M_K1,
|
||||
@@ -1509,7 +1514,8 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
|
||||
const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BPreshuffled& b_grid_desc_bpreshuffled,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock)
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const index_t k_id)
|
||||
{
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
@@ -1606,7 +1612,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
|
||||
true>(b_grid_desc_bpreshuffled,
|
||||
make_multi_index(n_block_data_idx_on_grid,
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
k_id,
|
||||
KPack * (get_thread_local_1d_id() % WarpSize)));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
@@ -1849,10 +1855,12 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
|
||||
CDataType* p_c_grid,
|
||||
void* p_shared_0,
|
||||
void* p_shared_1,
|
||||
const Problem& problem)
|
||||
const Problem& problem,
|
||||
const index_t k_id,
|
||||
const index_t Kt)
|
||||
{
|
||||
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
|
||||
index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
|
||||
index_t BK0Shuffled = CalculateBK0Shuffled(Kt);
|
||||
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
|
||||
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
|
||||
const auto b_grid_desc_bpreshuffled =
|
||||
@@ -1877,7 +1885,8 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
|
||||
problem,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bpreshuffled,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock);
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_id);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -43,18 +43,26 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
{
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
// Full K needed for matrix B
|
||||
const index_t Kt = karg.K;
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
|
||||
const index_t num_k_per_block = GridwiseGemm::CalculateBK0Shuffled(karg.K);
|
||||
const index_t k_id = blockIdx.z * num_k_per_block;
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
|
||||
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_b_grid,
|
||||
karg.p_ds_grid,
|
||||
karg.p_c_grid,
|
||||
p_shared,
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.c_element_op);
|
||||
karg.c_element_op,
|
||||
k_id,
|
||||
Kt);
|
||||
}
|
||||
#else
|
||||
ignore = karg;
|
||||
@@ -79,11 +87,17 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
__shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
// Full K needed for matrix B
|
||||
const index_t Kt = karg.K;
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
|
||||
const index_t num_k_per_block = GridwiseGemm::CalculateBK0Shuffled(karg.K);
|
||||
const index_t k_id = blockIdx.z * num_k_per_block;
|
||||
|
||||
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
|
||||
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_b_grid,
|
||||
karg.p_ds_grid,
|
||||
karg.p_c_grid,
|
||||
p_shared,
|
||||
@@ -91,7 +105,9 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.c_element_op);
|
||||
karg.c_element_op,
|
||||
k_id,
|
||||
Kt);
|
||||
}
|
||||
#else
|
||||
ignore = karg;
|
||||
@@ -691,16 +707,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
|
||||
a_k_split_offset = k_id * karg.KRead * karg.StrideA;
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
{
|
||||
b_k_split_offset = k_id * karg.KRead * karg.StrideB;
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
// KPack * NLane * KLane * K0 * N0
|
||||
b_k_split_offset = k_id * karg.KRead * NLane;
|
||||
}
|
||||
|
||||
if(k_id < karg.KBatch - 1)
|
||||
{
|
||||
karg.K = karg.KRead;
|
||||
@@ -712,7 +718,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
|
||||
}
|
||||
|
||||
index_t a_k_split_offset;
|
||||
index_t b_k_split_offset;
|
||||
};
|
||||
|
||||
__device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
|
||||
@@ -1163,7 +1168,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
|
||||
const Problem& problem,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
CElementwiseOperation c_element_op,
|
||||
const index_t k_id,
|
||||
const index_t Kt)
|
||||
{
|
||||
const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4};
|
||||
Run<Block2CTileMapDefault, HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
@@ -1176,7 +1183,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
block_2_ctile_map);
|
||||
block_2_ctile_map,
|
||||
k_id,
|
||||
Kt);
|
||||
}
|
||||
|
||||
template <typename Block2CTileMap,
|
||||
@@ -1192,11 +1201,13 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
const Block2CTileMap& block_2_ctile_map)
|
||||
const Block2CTileMap& block_2_ctile_map,
|
||||
const index_t k_id,
|
||||
const index_t Kt)
|
||||
{
|
||||
ignore = b_element_op;
|
||||
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
|
||||
index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
|
||||
index_t BK0Shuffled = CalculateBK0Shuffled(Kt);
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
|
||||
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
|
||||
@@ -1293,7 +1304,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
|
||||
true>(b_grid_desc_bpreshuffled,
|
||||
make_multi_index(n_block_data_idx_on_grid,
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
k_id,
|
||||
KPackPerGroup * (get_thread_local_1d_id() % WarpSize)));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
@@ -1597,7 +1608,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
|
||||
const Problem& problem,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
CElementwiseOperation c_element_op,
|
||||
const index_t k_id,
|
||||
const index_t Kt)
|
||||
{
|
||||
const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4};
|
||||
Run_2Lds<Block2CTileMapDefault, HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
@@ -1611,7 +1624,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
block_2_ctile_map);
|
||||
block_2_ctile_map,
|
||||
k_id,
|
||||
Kt);
|
||||
}
|
||||
|
||||
template <typename Block2CTileMap,
|
||||
@@ -1628,11 +1643,13 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
const Block2CTileMap& block_2_ctile_map)
|
||||
const Block2CTileMap& block_2_ctile_map,
|
||||
const index_t k_id,
|
||||
const index_t Kt)
|
||||
{
|
||||
ignore = b_element_op;
|
||||
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
|
||||
index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
|
||||
index_t BK0Shuffled = CalculateBK0Shuffled(Kt);
|
||||
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
|
||||
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
|
||||
|
||||
@@ -1731,7 +1748,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
|
||||
true>(b_grid_desc_bpreshuffled,
|
||||
make_multi_index(n_block_data_idx_on_grid,
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
k_id,
|
||||
KPackPerGroup * (get_thread_local_1d_id() % WarpSize)));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
|
||||
103
profiler/include/profiler/common.hpp
Normal file
103
profiler/include/profiler/common.hpp
Normal file
@@ -0,0 +1,103 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <type_traits>
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
|
||||
template <typename DataType, typename ComputeDataType = DataType>
|
||||
inline __host__ __device__ constexpr double get_rtol()
|
||||
{
|
||||
if constexpr(std::is_same_v<DataType, float> && std::is_same_v<ComputeDataType, ck::tf32_t>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, float>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, double>)
|
||||
{
|
||||
return 1e-6;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::half_t>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
|
||||
{
|
||||
return 5e-2;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int32_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int8_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
|
||||
{
|
||||
return 1e-1; // 240 and 224 are acceptable
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
|
||||
{
|
||||
return 1.5e-1; // 57344 and 49152 are acceptable
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DataType, typename ComputeDataType = DataType>
|
||||
inline __host__ __device__ constexpr double get_atol()
|
||||
{
|
||||
if constexpr(std::is_same_v<DataType, float> && std::is_same_v<ComputeDataType, ck::tf32_t>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, float>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, double>)
|
||||
{
|
||||
return 1e-6;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::half_t>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
|
||||
{
|
||||
return 5e-2;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int32_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int8_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
|
||||
{
|
||||
return 16.1; // 240 and 224 are acceptable
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
|
||||
{
|
||||
return 8192.1; // 57344 and 49152 are acceptable
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace profiler
|
||||
} // namespace ck
|
||||
@@ -69,19 +69,19 @@ template <typename A0DataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout>
|
||||
bool profile_gemm_blockscale_weighpreshuffle_impl(int do_verification,
|
||||
int init_method,
|
||||
bool do_log,
|
||||
bool time_kernel,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int StrideA,
|
||||
int StrideB,
|
||||
int StrideE,
|
||||
int n_warmup,
|
||||
int n_iter,
|
||||
uint64_t rotating = 0)
|
||||
bool profile_gemm_blockscale_weightpreshuffle_impl(int do_verification,
|
||||
int init_method,
|
||||
bool do_log,
|
||||
bool time_kernel,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int StrideA,
|
||||
int StrideB,
|
||||
int StrideE,
|
||||
int n_warmup,
|
||||
int n_iter,
|
||||
uint64_t rotating = 0)
|
||||
{
|
||||
bool pass = true;
|
||||
|
||||
@@ -126,6 +126,26 @@ bool profile_gemm_blockscale_weighpreshuffle_impl(int do_verification,
|
||||
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
|
||||
// Update strides based on tensor properties if they are <= 0
|
||||
auto get_stride = [](auto& tensor, auto layout, ck::index_t current_stride) -> ck::index_t {
|
||||
if(current_stride <= 0)
|
||||
{
|
||||
if constexpr(std::is_same_v<decltype(layout), tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return tensor.GetStrides()[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
return tensor.GetStrides()[1];
|
||||
}
|
||||
}
|
||||
return current_stride;
|
||||
};
|
||||
|
||||
StrideA = get_stride(a0_m_k, ALayout{}, StrideA);
|
||||
StrideB = get_stride(b0_k_n, BLayout{}, StrideB);
|
||||
StrideE = get_stride(e_m_n_host_result, ELayout{}, StrideE);
|
||||
|
||||
int total_gemm_needed =
|
||||
a0_m_k.GetElementSpaceSizeInBytes() + b0_k_n.GetElementSpaceSizeInBytes() +
|
||||
a1_m_k.GetElementSpaceSizeInBytes() + b1_k_n.GetElementSpaceSizeInBytes();
|
||||
|
||||
@@ -20,6 +20,7 @@
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
#include "profiler/common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
@@ -112,6 +113,28 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
|
||||
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
|
||||
// Update strides based on tensor properties if they are <= 0
|
||||
auto get_stride = [](auto& tensor, auto layout, ck::index_t current_stride) -> ck::index_t {
|
||||
if(current_stride <= 0)
|
||||
{
|
||||
if constexpr(std::is_same_v<decltype(layout), tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return tensor.GetStrides()[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
return tensor.GetStrides()[1];
|
||||
}
|
||||
}
|
||||
return current_stride;
|
||||
};
|
||||
|
||||
StrideA = get_stride(a_m_k, ALayout{}, StrideA);
|
||||
StrideB = get_stride(b_k_n, BLayout{}, StrideB);
|
||||
StrideD0 = get_stride(d0_m_n, D0Layout{}, StrideD0);
|
||||
StrideD1 = get_stride(d1_m_n, D1Layout{}, StrideD1);
|
||||
StrideE = get_stride(e_m_n_host_result, ELayout{}, StrideE);
|
||||
|
||||
int total_gemm_needed =
|
||||
a_m_k.GetElementSpaceSizeInBytes() + b_k_n.GetElementSpaceSizeInBytes() +
|
||||
d0_m_n.GetElementSpaceSizeInBytes() + d1_m_n.GetElementSpaceSizeInBytes();
|
||||
@@ -133,7 +156,7 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
|
||||
case 1:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-1, 2});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-1, 2});
|
||||
d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-5, 5});
|
||||
d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-1, 1});
|
||||
d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-1, 1});
|
||||
break;
|
||||
default:
|
||||
@@ -282,8 +305,8 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
|
||||
is_same_v<EDataType, int8_t>))
|
||||
{
|
||||
std::string msg = "Error: Incorrect results!";
|
||||
double rtol = 1e-3;
|
||||
double atol = 5e-2;
|
||||
double rtol = get_rtol<EDataType>();
|
||||
double atol = get_atol<EDataType>();
|
||||
pass = pass & ck::utils::check_err(
|
||||
e_m_n_device_result, e_m_n_host_result, msg, rtol, atol);
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
#include "profiler/common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
@@ -99,6 +100,26 @@ bool profile_gemm_universal_preshuffle_impl(int do_verification,
|
||||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
// Update strides based on tensor properties if they are <= 0
|
||||
auto get_stride = [](auto& tensor, auto layout, ck::index_t current_stride) -> ck::index_t {
|
||||
if(current_stride <= 0)
|
||||
{
|
||||
if constexpr(std::is_same_v<decltype(layout), tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return tensor.GetStrides()[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
return tensor.GetStrides()[1];
|
||||
}
|
||||
}
|
||||
return current_stride;
|
||||
};
|
||||
|
||||
StrideA = get_stride(a_m_k, ALayout{}, StrideA);
|
||||
StrideB = get_stride(b_k_n, BLayout{}, StrideB);
|
||||
StrideC = get_stride(c_m_n_host_result, CLayout{}, StrideC);
|
||||
|
||||
std::size_t total_gemm_needed =
|
||||
a_m_k.GetElementSpaceSizeInBytes() + b_k_n.GetElementSpaceSizeInBytes();
|
||||
int rotating_count = std::max(
|
||||
@@ -317,8 +338,8 @@ bool profile_gemm_universal_preshuffle_impl(int do_verification,
|
||||
is_same_v<CDataType, f8_t>)
|
||||
{
|
||||
std::string msg = "Error: Incorrect results!";
|
||||
double rtol = 1e-1;
|
||||
double atol = 1e-1;
|
||||
double rtol = get_rtol<CDataType>();
|
||||
double atol = get_atol<CDataType>();
|
||||
pass = pass & ck::utils::check_err(
|
||||
c_m_n_device_result, c_m_n_host_result, msg, rtol, atol);
|
||||
}
|
||||
|
||||
@@ -5,92 +5,11 @@
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "profiler/common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
|
||||
template <typename DataType>
|
||||
inline constexpr double get_rtol()
|
||||
{
|
||||
if constexpr(std::is_same_v<DataType, float>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, double>)
|
||||
{
|
||||
return 1e-6;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::half_t>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
|
||||
{
|
||||
return 5e-2;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int32_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int8_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
|
||||
{
|
||||
return 1e-1; // 240 and 224 are acceptable
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
|
||||
{
|
||||
return 1.5e-1; // 57344 and 49152 are acceptable
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
inline constexpr double get_atol()
|
||||
{
|
||||
if constexpr(std::is_same_v<DataType, float>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, double>)
|
||||
{
|
||||
return 1e-6;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::half_t>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
|
||||
{
|
||||
return 5e-2;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int32_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int8_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
|
||||
{
|
||||
return 16.1; // 240 and 224 are acceptable
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
|
||||
{
|
||||
return 8192.1; // 57344 and 49152 are acceptable
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
}
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
|
||||
@@ -126,19 +126,19 @@ int profile_gemm_blockscale_weighpreshuffle(int argc, char* argv[])
|
||||
const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
|
||||
const int DefaultStrideE = ck::is_same_v<ELayout, Row> ? N : M;
|
||||
|
||||
bool pass = ck::profiler::profile_gemm_blockscale_weighpreshuffle_impl<A0DataType,
|
||||
A1DataType,
|
||||
B0DataType,
|
||||
B1DataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
scale_block_m,
|
||||
scale_block_n,
|
||||
scale_block_k,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout>(
|
||||
bool pass = ck::profiler::profile_gemm_blockscale_weightpreshuffle_impl<A0DataType,
|
||||
A1DataType,
|
||||
B0DataType,
|
||||
B1DataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
scale_block_m,
|
||||
scale_block_n,
|
||||
scale_block_k,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout>(
|
||||
do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
|
||||
@@ -245,10 +245,13 @@ add_subdirectory(conv_util)
|
||||
add_subdirectory(reference_conv_fwd)
|
||||
add_subdirectory(gemm)
|
||||
add_subdirectory(gemm_add)
|
||||
add_subdirectory(gemm_blockscale_wp)
|
||||
add_subdirectory(gemm_layernorm)
|
||||
add_subdirectory(gemm_multi_abd)
|
||||
add_subdirectory(gemm_multiply_multiply_wp)
|
||||
add_subdirectory(gemm_split_k)
|
||||
add_subdirectory(gemm_universal)
|
||||
add_subdirectory(gemm_universal_preshuffle)
|
||||
add_subdirectory(gemm_b_scale)
|
||||
add_subdirectory(gemm_universal_streamk)
|
||||
add_subdirectory(gemm_reduce)
|
||||
|
||||
6
test/gemm_blockscale_wp/CMakeLists.txt
Normal file
6
test/gemm_blockscale_wp/CMakeLists.txt
Normal file
@@ -0,0 +1,6 @@
|
||||
if(GPU_TARGETS MATCHES "gfx9[45]|gfx12")
|
||||
add_gtest_executable(test_gemm_blockscale_wp_xdl_fp8 test_gemm_blockscale_wp_xdl_fp8.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_blockscale_wp_xdl_fp8 PRIVATE utility device_gemm_blockscale_wp_instance)
|
||||
endif()
|
||||
endif()
|
||||
64
test/gemm_blockscale_wp/test_gemm_blockscale_wp_xdl_fp8.cpp
Normal file
64
test/gemm_blockscale_wp/test_gemm_blockscale_wp_xdl_fp8.cpp
Normal file
@@ -0,0 +1,64 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "test_gemm_common.hpp"
|
||||
|
||||
using F8 = ck::f8_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename X, typename Y>
|
||||
struct tuple_concat;
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
struct tuple_concat<std::tuple<Xs...>, std::tuple<Ys...>>
|
||||
{
|
||||
using type = std::tuple<Xs..., Ys...>;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmBlockScaleWP_FP8_MK_NK : public ck::test::TestGemmBlockscaleWPCommon<
|
||||
typename tuple_concat<std::tuple<Row, Col>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes_MK_NK = ::testing::Types<
|
||||
#if defined(CK_ENABLE_FP8)
|
||||
std::tuple< F8, F32, F8, F32, F8, BF16>
|
||||
#endif
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmBlockScaleWP_FP8_MK_NK, KernelTypes_MK_NK);
|
||||
|
||||
TYPED_TEST(TestGemmBlockScaleWP_FP8_MK_NK, Regular0)
|
||||
{
|
||||
std::vector<int> Ms{128, 256, 512};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 2048;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmBlockScaleWP_FP8_MK_NK, Regular1)
|
||||
{
|
||||
std::vector<int> Ms{128, 256, 512};
|
||||
constexpr int N = 1024;
|
||||
constexpr int K = 4096;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
77
test/gemm_blockscale_wp/test_gemm_common.hpp
Normal file
77
test/gemm_blockscale_wp/test_gemm_common.hpp
Normal file
@@ -0,0 +1,77 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/ck.hpp"
|
||||
#include "profiler/profile_gemm_blockscale_wp_impl.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace test {
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using F32 = float;
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmBlockscaleWPCommon : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using CLayout = Row;
|
||||
using A0DataType = std::tuple_element_t<2, Tuple>;
|
||||
using A1DataType = std::tuple_element_t<3, Tuple>;
|
||||
using B0DataType = std::tuple_element_t<4, Tuple>;
|
||||
using B1DataType = std::tuple_element_t<5, Tuple>;
|
||||
using ComputeDataType = std::tuple_element_t<6, Tuple>;
|
||||
using CDataType = std::tuple_element_t<7, Tuple>;
|
||||
|
||||
public:
|
||||
static constexpr bool verify_ = true;
|
||||
static constexpr int init_method_ = 1;
|
||||
static constexpr bool log_ = false;
|
||||
static constexpr bool bench_ = false;
|
||||
static constexpr index_t ScaleBlockM = 1;
|
||||
static constexpr index_t ScaleBlockN = 128;
|
||||
static constexpr index_t ScaleBlockK = 128;
|
||||
|
||||
void Run(const int M, const int N, const int K, int n_warmup = 1, int n_iter = 10)
|
||||
{
|
||||
bool all_success = true;
|
||||
|
||||
int StrideA = std::is_same_v<ALayout, Row> ? K : M;
|
||||
int StrideB = std::is_same_v<BLayout, Row> ? N : K;
|
||||
int StrideC = std::is_same_v<CLayout, Row> ? N : M;
|
||||
|
||||
all_success =
|
||||
all_success &
|
||||
ck::profiler::profile_gemm_blockscale_weightpreshuffle_impl<A0DataType,
|
||||
A1DataType,
|
||||
B0DataType,
|
||||
B1DataType,
|
||||
ComputeDataType,
|
||||
F32,
|
||||
CDataType,
|
||||
ScaleBlockM,
|
||||
ScaleBlockN,
|
||||
ScaleBlockK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(verify_,
|
||||
init_method_,
|
||||
log_,
|
||||
bench_,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
n_warmup,
|
||||
n_iter);
|
||||
|
||||
EXPECT_TRUE(all_success);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace test
|
||||
} // namespace ck
|
||||
6
test/gemm_multiply_multiply_wp/CMakeLists.txt
Normal file
6
test/gemm_multiply_multiply_wp/CMakeLists.txt
Normal file
@@ -0,0 +1,6 @@
|
||||
if(GPU_TARGETS MATCHES "gfx9[45]|gfx12")
|
||||
add_gtest_executable(test_gemm_multiply_multiply_wp_xdl_fp8 test_gemm_multiply_multiply_wp_xdl_fp8.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_multiply_multiply_wp_xdl_fp8 PRIVATE utility device_gemm_multiply_multiply_wp_instance)
|
||||
endif()
|
||||
endif()
|
||||
93
test/gemm_multiply_multiply_wp/test_gemm_common.hpp
Normal file
93
test/gemm_multiply_multiply_wp/test_gemm_common.hpp
Normal file
@@ -0,0 +1,93 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/ck.hpp"
|
||||
#include "profiler/profile_gemm_multiply_multiply_wp_impl.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace test {
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using F32 = float;
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmMultiplyMultiplyWPCommon : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using D0Layout = std::tuple_element_t<2, Tuple>;
|
||||
using D1Layout = std::tuple_element_t<3, Tuple>;
|
||||
using ELayout = Row;
|
||||
using ADataType = std::tuple_element_t<4, Tuple>;
|
||||
using BDataType = std::tuple_element_t<5, Tuple>;
|
||||
using ComputeDataType = std::tuple_element_t<6, Tuple>;
|
||||
using D0DataType = std::tuple_element_t<7, Tuple>;
|
||||
using D1DataType = std::tuple_element_t<8, Tuple>;
|
||||
using EDataType = std::tuple_element_t<9, Tuple>;
|
||||
|
||||
public:
|
||||
static constexpr bool verify_ = true;
|
||||
static constexpr int init_method_ = 1; // decimal value initialization
|
||||
static constexpr bool log_ = false;
|
||||
static constexpr bool bench_ = false; // measure kernel performance
|
||||
std::vector<int> k_batches_;
|
||||
|
||||
void SetUp() override { k_batches_ = {1, 2, 4}; }
|
||||
|
||||
void Run(const int M, const int N, const int K)
|
||||
{
|
||||
for(size_t i = 0; i < k_batches_.size(); i++)
|
||||
{
|
||||
RunSingle(M, N, K, k_batches_[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void RunSingle(
|
||||
const int M, const int N, const int K, int kbatch = 1, int n_warmup = 1, int n_iter = 10)
|
||||
{
|
||||
bool all_success = true;
|
||||
|
||||
int StrideA = std::is_same_v<remove_cvref_t<ALayout>, Row> ? K : M;
|
||||
int StrideB = std::is_same_v<remove_cvref_t<BLayout>, Row> ? N : K;
|
||||
int StrideD0 = std::is_same_v<remove_cvref_t<D0Layout>, Row> ? N : M;
|
||||
int StrideD1 = std::is_same_v<remove_cvref_t<D1Layout>, Row> ? N : M;
|
||||
int StrideE = std::is_same_v<ELayout, Row> ? N : M;
|
||||
|
||||
all_success =
|
||||
all_success &
|
||||
ck::profiler::profile_gemm_multiply_multiply_weight_preshuffle_impl<ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
F32,
|
||||
D0DataType,
|
||||
D1DataType,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
D0Layout,
|
||||
D1Layout,
|
||||
ELayout>(
|
||||
verify_,
|
||||
init_method_,
|
||||
log_,
|
||||
bench_,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideD0,
|
||||
StrideD1,
|
||||
StrideE,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_iter);
|
||||
|
||||
EXPECT_TRUE(all_success);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace test
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,77 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "test_gemm_common.hpp"
|
||||
|
||||
using F8 = ck::f8_t;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename X, typename Y>
|
||||
struct tuple_concat;
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
struct tuple_concat<std::tuple<Xs...>, std::tuple<Ys...>>
|
||||
{
|
||||
using type = std::tuple<Xs..., Ys...>;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmMultiplyMultiplyWP_FP8_MK_NK
|
||||
: public ck::test::TestGemmMultiplyMultiplyWPCommon<
|
||||
typename tuple_concat<std::tuple<Row, Col, Row, Col>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes_MK_NK = ::testing::Types<
|
||||
#if defined(CK_ENABLE_FP8)
|
||||
std::tuple< F8, F8, F8, F32, F32, F16>,
|
||||
std::tuple< F8, F8, F8, F32, F32, BF16>
|
||||
#endif
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmMultiplyMultiplyWP_FP8_MK_NK, KernelTypes_MK_NK);
|
||||
|
||||
TYPED_TEST(TestGemmMultiplyMultiplyWP_FP8_MK_NK, Regular0)
|
||||
{
|
||||
std::vector<int> Ms{128, 224, 256, 448, 512};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 2048;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmMultiplyMultiplyWP_FP8_MK_NK, Regular1)
|
||||
{
|
||||
std::vector<int> Ms{128, 224, 256, 448, 512};
|
||||
constexpr int N = 1024;
|
||||
constexpr int K = 4096;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmMultiplyMultiplyWP_FP8_MK_NK, Regular2)
|
||||
{
|
||||
std::vector<int> Ms{128, 256, 512};
|
||||
constexpr int N = 448;
|
||||
constexpr int K = 2048;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
6
test/gemm_universal_preshuffle/CMakeLists.txt
Normal file
6
test/gemm_universal_preshuffle/CMakeLists.txt
Normal file
@@ -0,0 +1,6 @@
|
||||
if(GPU_TARGETS MATCHES "gfx9[45]|gfx12")
|
||||
add_gtest_executable(test_gemm_universal_preshuffle_xdl_fp8 test_gemm_universal_preshuffle_xdl_fp8.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_universal_preshuffle_xdl_fp8 PRIVATE utility device_gemm_universal_preshuffle_instance)
|
||||
endif()
|
||||
endif()
|
||||
79
test/gemm_universal_preshuffle/test_gemm_common.hpp
Normal file
79
test/gemm_universal_preshuffle/test_gemm_common.hpp
Normal file
@@ -0,0 +1,79 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/ck.hpp"
|
||||
#include "profiler/profile_gemm_universal_preshuffle_impl.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace test {
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using F32 = float;
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmUniversalPreshuffleCommon : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using CLayout = Row;
|
||||
using ADataType = std::tuple_element_t<2, Tuple>;
|
||||
using BDataType = std::tuple_element_t<3, Tuple>;
|
||||
using ComputeDataType = std::tuple_element_t<4, Tuple>;
|
||||
using CDataType = std::tuple_element_t<5, Tuple>;
|
||||
|
||||
public:
|
||||
static constexpr bool verify_ = true;
|
||||
static constexpr int init_method_ = 1;
|
||||
static constexpr bool log_ = false;
|
||||
static constexpr bool bench_ = false;
|
||||
std::vector<int> k_batches_;
|
||||
|
||||
void SetUp() override { k_batches_ = {1, 2, 4}; }
|
||||
|
||||
void Run(const int M, const int N, const int K)
|
||||
{
|
||||
for(size_t i = 0; i < k_batches_.size(); i++)
|
||||
{
|
||||
RunSingle(M, N, K, k_batches_[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void RunSingle(
|
||||
const int M, const int N, const int K, int kbatch = 1, int n_warmup = 1, int n_iter = 10)
|
||||
{
|
||||
bool all_success = true;
|
||||
|
||||
int StrideA = std::is_same_v<ALayout, Row> ? K : M;
|
||||
int StrideB = std::is_same_v<BLayout, Row> ? N : K;
|
||||
int StrideC = std::is_same_v<CLayout, Row> ? N : M;
|
||||
|
||||
all_success = all_success &
|
||||
ck::profiler::profile_gemm_universal_preshuffle_impl<ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
F32,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(verify_,
|
||||
init_method_,
|
||||
log_,
|
||||
bench_,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_iter);
|
||||
|
||||
EXPECT_TRUE(all_success);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace test
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,77 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "test_gemm_common.hpp"
|
||||
|
||||
using F8 = ck::f8_t;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename X, typename Y>
|
||||
struct tuple_concat;
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
struct tuple_concat<std::tuple<Xs...>, std::tuple<Ys...>>
|
||||
{
|
||||
using type = std::tuple<Xs..., Ys...>;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmUniversalPreshuffle_FP8_MK_NK
|
||||
: public ck::test::TestGemmUniversalPreshuffleCommon<
|
||||
typename tuple_concat<std::tuple<Row, Col>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes_MK_NK = ::testing::Types<
|
||||
#if defined(CK_ENABLE_FP8)
|
||||
std::tuple< F8, F8, F8, F16>,
|
||||
std::tuple< F8, F8, F8, BF16>
|
||||
#endif
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmUniversalPreshuffle_FP8_MK_NK, KernelTypes_MK_NK);
|
||||
|
||||
TYPED_TEST(TestGemmUniversalPreshuffle_FP8_MK_NK, Regular0)
|
||||
{
|
||||
std::vector<int> Ms{128, 224, 256, 448, 512};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 2048;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversalPreshuffle_FP8_MK_NK, Regular1)
|
||||
{
|
||||
std::vector<int> Ms{128, 224, 256, 448, 512};
|
||||
constexpr int N = 1024;
|
||||
constexpr int K = 4096;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversalPreshuffle_FP8_MK_NK, Regular2)
|
||||
{
|
||||
std::vector<int> Ms{128, 256, 512};
|
||||
constexpr int N = 448;
|
||||
constexpr int K = 2048;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
Reference in New Issue
Block a user