mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
Apply changes suggested in PR review: remove leftover comments and correct the copyright year
This commit is contained in:
0
.pre-commit-config.yaml
Normal file → Executable file
0
.pre-commit-config.yaml
Normal file → Executable file
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -117,7 +117,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
|
||||
auto f_get_default_stride =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if(stride == 0)
|
||||
if(stride == -1)
|
||||
{
|
||||
// give a chance if stride is -1, return a default packed stride
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
@@ -162,18 +162,6 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
}
|
||||
#if 0
|
||||
printf("B matrix:\n");
|
||||
for (int in = 0; in < N; in++)
|
||||
{
|
||||
for (int ik = 0; ik < K; ik++)
|
||||
{
|
||||
printf("%02x ", *(reinterpret_cast<uint8_t*>(&b_k_n(ik,in))));
|
||||
if(ik%8==7) printf("|");
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
#endif
|
||||
|
||||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -147,10 +147,8 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
|
||||
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
|
||||
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
|
||||
0,
|
||||
arg.M * arg.N * sizeof(CDataType),
|
||||
stream_config.stream_id_)); // HS
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
arg.p_c_grid, 0, arg.M * arg.N * sizeof(CDataType), stream_config.stream_id_));
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
dim3 grid_dim;
|
||||
if(arg.Grid_size < 0)
|
||||
@@ -193,25 +191,13 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
|
||||
};
|
||||
|
||||
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
|
||||
stream_config,
|
||||
run_flush_cache,
|
||||
kernel,
|
||||
grid_dim,
|
||||
// dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg_);
|
||||
stream_config, run_flush_cache, kernel, grid_dim, dim3(BlockSize), 0, arg_);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
ave_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
// dim3(gdx, gdy, gdz),
|
||||
grid_dim,
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg);
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config, kernel, grid_dim, dim3(BlockSize), 0, arg);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -477,7 +463,7 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation)
|
||||
{
|
||||
// return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, KBatch};
|
||||
|
||||
return Argument{
|
||||
p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, streamk_sel, Grid_size}; // HS
|
||||
}
|
||||
|
||||
@@ -1461,31 +1461,27 @@ struct BlockToCTileMap_GemmStreamK_v2
|
||||
|
||||
// check if there's enough work for DP+ stream-k
|
||||
bool bigEnough = num_tiles > grid_size;
|
||||
// select between 1 tile and 2 tile sk
|
||||
// select between stream-k strategies
|
||||
uint32_t sk_tiles = 0;
|
||||
if(streamk_sel == 1)
|
||||
if(streamk_sel == 1) // 1 tile stream-k
|
||||
{
|
||||
sk_tiles = bigEnough ? (num_tiles % grid_size) : num_tiles;
|
||||
}
|
||||
else if(streamk_sel == 2)
|
||||
else if(streamk_sel == 2) // 2-tile stream-k
|
||||
{
|
||||
sk_tiles = bigEnough ? (grid_size + num_tiles % grid_size) : num_tiles;
|
||||
}
|
||||
else if(streamk_sel == 3)
|
||||
else if(streamk_sel == 3) // 3-tile stream-k
|
||||
{
|
||||
sk_tiles = (num_tiles > (2 * grid_size)) ? (2 * grid_size + num_tiles % grid_size)
|
||||
: num_tiles;
|
||||
}
|
||||
else if(streamk_sel == 4)
|
||||
else if(streamk_sel == 4) // 4-tile stream-k
|
||||
{
|
||||
sk_tiles = (num_tiles > (3 * grid_size)) ? (3 * grid_size + num_tiles % grid_size)
|
||||
: num_tiles;
|
||||
}
|
||||
sk_num_blocks = sk_tiles;
|
||||
// if(sk_tiles < sk_num_blocks)
|
||||
// {
|
||||
// sk_num_blocks = sk_tiles;
|
||||
// }
|
||||
// remaining tiles are DP tiles
|
||||
dp_tiles = bigEnough ? (num_tiles - sk_tiles) : 0;
|
||||
|
||||
@@ -1508,7 +1504,6 @@ struct BlockToCTileMap_GemmStreamK_v2
|
||||
|
||||
dp_num_blocks = dp_tiles;
|
||||
dp_start_block_idx = sk_num_blocks;
|
||||
// dp_start_block_idx = ((sk_num_blocks + grid_size - 1) / grid_size) * grid_size;
|
||||
}
|
||||
|
||||
n_tiles = MDiv2(math::integer_divide_ceil(n, NPerBlock));
|
||||
@@ -1523,30 +1518,29 @@ struct BlockToCTileMap_GemmStreamK_v2
|
||||
equiv_tiles_little = MDiv(upper_little / k_iters_per_tile.get());
|
||||
}
|
||||
|
||||
#if 0
|
||||
printf("streamk_sel=%0d,grid_size=%0d, num_tiles:%d, dp_tiles:%d, sk_tiles:%u, "
|
||||
"sk_num_blocks:%d,dp_num_blocks:%d,sk_num_big_blocks:%d, "
|
||||
"sk_total_iters:%d, dp_start_block_idx:%d, "
|
||||
"k_iters_per_tile:%d, k_iters_per_big_block:%d, reduction_start_block_idx:%u, "
|
||||
" workspace(acc float):%u\n",
|
||||
streamk_sel,
|
||||
grid_size,
|
||||
// occupancy,
|
||||
// get_grid_dims(num_cu, occupancy).x,
|
||||
num_tiles,
|
||||
dp_tiles,
|
||||
get_sk_tiles(),
|
||||
sk_num_blocks,
|
||||
dp_num_blocks,
|
||||
sk_num_big_blocks,
|
||||
sk_total_iters,
|
||||
dp_start_block_idx,
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
printf("streamk_sel=%0d,grid_size=%0d, num_tiles:%d, dp_tiles:%d, sk_tiles:%u, "
|
||||
"sk_num_blocks:%d,dp_num_blocks:%d,sk_num_big_blocks:%d, "
|
||||
"sk_total_iters:%d, dp_start_block_idx:%d, "
|
||||
"k_iters_per_tile:%d, k_iters_per_big_block:%d, reduction_start_block_idx:%u, "
|
||||
" workspace(acc float):%u\n",
|
||||
streamk_sel,
|
||||
grid_size,
|
||||
num_tiles,
|
||||
dp_tiles,
|
||||
get_sk_tiles(),
|
||||
sk_num_blocks,
|
||||
dp_num_blocks,
|
||||
sk_num_big_blocks,
|
||||
sk_total_iters,
|
||||
dp_start_block_idx,
|
||||
|
||||
k_iters_per_tile.get(),
|
||||
k_iters_per_big_block,
|
||||
reduction_start_block_idx,
|
||||
get_workspace_size(sizeof(float)));
|
||||
#endif
|
||||
k_iters_per_tile.get(),
|
||||
k_iters_per_big_block,
|
||||
reduction_start_block_idx,
|
||||
get_workspace_size(sizeof(float)));
|
||||
}
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
|
||||
@@ -1656,90 +1650,6 @@ struct BlockToCTileMap_GemmStreamK_v2
|
||||
m_tile_idx_with_adapt = tile_idx_local % sub_m_adapt;
|
||||
return make_tuple(m_tile_idx_with_adapt + m_tile_idx_sub0 * tile_swizzle_sub_m,
|
||||
n_tile_idx_with_adapt);
|
||||
|
||||
// adding gfx94x optimized
|
||||
// index_t block_1d_id = tile_idx;
|
||||
// const index_t N0 = n_tiles_value;
|
||||
// const index_t M0 = math::integer_divide_ceil(n * m / m, MPerBlock);
|
||||
// // index_t GroupNum = 8;
|
||||
// // index_t M01_ = 4;
|
||||
|
||||
// if(M0 == 1)
|
||||
// {
|
||||
// return make_tuple(0, block_1d_id);
|
||||
// }
|
||||
// else if(N0 == 1)
|
||||
// {
|
||||
// return make_tuple(block_1d_id, 0);
|
||||
// }
|
||||
// // block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
|
||||
// else
|
||||
// {
|
||||
// const auto group_size = math::integer_divide_ceil(M0 * N0, GroupNum);
|
||||
// const auto big_group_num = GroupNum - (group_size * GroupNum - M0 * N0);
|
||||
// auto group_id_x = block_1d_id % GroupNum;
|
||||
// auto group_id_y = block_1d_id / GroupNum;
|
||||
// auto remap_block_1d_id =
|
||||
// group_id_x <= big_group_num
|
||||
// ? group_id_x * group_size + group_id_y
|
||||
// : group_id_x * group_size + big_group_num - group_id_x + group_id_y;
|
||||
|
||||
// index_t idx_N0 = remap_block_1d_id % N0;
|
||||
// index_t idx_M0 = remap_block_1d_id / N0;
|
||||
|
||||
// const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
|
||||
|
||||
// index_t idx_M00 = idx_M0 / M01_;
|
||||
// index_t idx_M01 = idx_M0 % M01_;
|
||||
// index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
|
||||
|
||||
// /**
|
||||
// * idxN0
|
||||
// *
|
||||
// * |< mtx N >|
|
||||
// *
|
||||
// * NPerBlock NPerBlock NPerBlock NPerBlock
|
||||
// * N_0 N_1 N_2 N_3
|
||||
// * - |-----------|-----------|-----------|-----|-----|-
|
||||
// * ^ | - - 0 |/----> 2 | | | |
|
||||
// * | | | / | | | | | M_0 MPerBlock
|
||||
// * | M | /| | | | | |
|
||||
// * |-0---|---/-|-----|-----|-----------|-----|-----|-
|
||||
// * | 1 | / | | | blockid | | |
|
||||
// * idxM0 | | | / | V | 5 | | | M_1 MPerBlock
|
||||
// * | - V 1 | - 3 | | | |
|
||||
// * |-----------|-----------|-----------|-----|-----|-
|
||||
// * mtx M | | | | | |
|
||||
// * | | | | | | M_2 MPerBlock
|
||||
// * | | | | | |
|
||||
// * |-----------|-----------|-----------|-----|-----|-
|
||||
// * | | | | | |
|
||||
// * | | | | | | M_3 MPerBlock
|
||||
// * | | | | | |
|
||||
// * |-----------|-----------|-----------|-----|-----|-
|
||||
// * V | | | | | |
|
||||
// * - |-----------|-----------|-----------|-----|-----|- M_4 MPerBlock
|
||||
// * | | | | | |
|
||||
// * |-----------|-----------|-----------|-----|-----|-
|
||||
// * Example:
|
||||
// * assume:
|
||||
// * M0 = 5
|
||||
// * N0 = 4
|
||||
// * block_1d_id = 5
|
||||
// * M01 = 2
|
||||
// *
|
||||
// * idx_N0 = 1
|
||||
// * idx_M0 = 1
|
||||
// * M01_adapt = 2
|
||||
// * idx_M00 = 0
|
||||
// * idx_M01 = 1
|
||||
// * idx_N0_M01_local = 5
|
||||
// * output {1, 2}
|
||||
// */
|
||||
|
||||
// return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
|
||||
// idx_N0_M01_local / M01_adapt);
|
||||
//}
|
||||
}
|
||||
|
||||
__host__ __device__ uint32_t get_workspace_size_for_acc(uint32_t acc_element_bytes) const
|
||||
|
||||
@@ -32,22 +32,13 @@ __global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
#endif
|
||||
// __attribute__((amdgpu_waves_per_eu(1, 1)))
|
||||
kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
|
||||
kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
// auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
// karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
|
||||
// karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_a_grid,
|
||||
karg.p_b_grid,
|
||||
karg.p_c_grid,
|
||||
p_shared,
|
||||
karg);
|
||||
karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg);
|
||||
#else
|
||||
ignore = karg;
|
||||
#endif // end of if (defined(__gfx9__))
|
||||
@@ -62,8 +53,7 @@ __global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
#endif
|
||||
// __attribute__((amdgpu_waves_per_eu(1, 1)))
|
||||
kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
|
||||
kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
// Pass two lds pointer is the key to tell compiler that ds_read/write
|
||||
@@ -71,17 +61,8 @@ __global__ void
|
||||
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
// auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
|
||||
|
||||
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
// karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
|
||||
// karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_a_grid,
|
||||
karg.p_b_grid,
|
||||
karg.p_c_grid,
|
||||
p_shared_0,
|
||||
p_shared_1,
|
||||
karg);
|
||||
karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared_0, p_shared_1, karg);
|
||||
#else
|
||||
ignore = karg;
|
||||
#endif // end of if (defined(__gfx9__))
|
||||
@@ -155,15 +136,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
// __host__ static auto CalculateGridSize(index_t M, index_t N) //, index_t KBatch)
|
||||
// {
|
||||
// // return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch);
|
||||
// // return ((Block2CTileMap::CalculateGridSize(M, N)) * KBatch);
|
||||
// // return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1);
|
||||
// return Block2CTileMap::CalculateGridSize(M, N);
|
||||
// }
|
||||
|
||||
__host__ static auto CalculateMPadded(index_t M)
|
||||
{
|
||||
return math::integer_least_multiple(M, MPerBlock);
|
||||
@@ -995,10 +967,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
}
|
||||
else
|
||||
{
|
||||
// constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
|
||||
// auto K_t = KReadVec;
|
||||
// auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
|
||||
if(karg.K <= 0) // HS
|
||||
|
||||
if(karg.K <= 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -1103,10 +1073,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
<< std::endl;
|
||||
}
|
||||
// if(karg.KBatch > 1)
|
||||
// {
|
||||
// return false;
|
||||
// }
|
||||
}
|
||||
|
||||
// check gridwise gemm pipeline
|
||||
@@ -1152,16 +1118,12 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
return c_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
}
|
||||
|
||||
// return block_id to C matrix tile idx (m0, n0) mapping
|
||||
// if arch = gfx942
|
||||
// using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
|
||||
// using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
|
||||
using Block2CTileMap_streamk = BlockToCTileMap_GemmStreamK_v2<MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
StreamKReductionStrategy::Atomic,
|
||||
8,
|
||||
4>; // HS
|
||||
4>;
|
||||
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
@@ -1177,43 +1139,39 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
const BElementwiseOperation b_element_op{};
|
||||
const CElementwiseOperation c_element_op{};
|
||||
|
||||
// Provide a value for TileSwizzleSubM_
|
||||
Block2CTileMap_streamk block_2_ctile_map_streamk(problem.M,
|
||||
problem.N,
|
||||
AK0Number * problem.KPadded,
|
||||
problem.Grid_size,
|
||||
problem.Streamk_sel); // HS
|
||||
uint32_t iter_start, iter_end; // HS
|
||||
bool is_sk_block, is_dp_block; //, is_padding_block; //, is_reduction_block; // HS
|
||||
index_t num_k_block_main_loop; // HS
|
||||
problem.Streamk_sel);
|
||||
uint32_t iter_start, iter_end;
|
||||
bool is_sk_block, is_dp_block;
|
||||
index_t num_k_block_main_loop;
|
||||
|
||||
for(auto block_idx = get_block_1d_id();
|
||||
block_idx < block_2_ctile_map_streamk.get_grid_dims();
|
||||
block_idx += gridDim.x)
|
||||
{
|
||||
// for(unsigned int kbatch_id = 0; kbatch_id < static_cast<unsigned
|
||||
// int>(problem.KBatch);
|
||||
// kbatch_id++)
|
||||
|
||||
is_sk_block =
|
||||
static_cast<uint32_t>(block_idx) < block_2_ctile_map_streamk.sk_num_blocks;
|
||||
is_dp_block =
|
||||
static_cast<uint32_t>(block_idx) >= block_2_ctile_map_streamk.dp_start_block_idx &&
|
||||
static_cast<uint32_t>(block_idx) <
|
||||
block_2_ctile_map_streamk.reduction_start_block_idx; // HS
|
||||
block_2_ctile_map_streamk.reduction_start_block_idx;
|
||||
|
||||
block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end); // HS
|
||||
num_k_block_main_loop = iter_end - iter_start; // HS
|
||||
block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end);
|
||||
num_k_block_main_loop = iter_end - iter_start;
|
||||
|
||||
while(true)
|
||||
{
|
||||
uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(
|
||||
block_2_ctile_map_streamk.get_current_iter_length(
|
||||
iter_start, iter_end, num_k_block_main_loop)); // HS
|
||||
uint32_t tile_idx, iter_offset; // HS
|
||||
iter_start, iter_end, num_k_block_main_loop));
|
||||
uint32_t tile_idx, iter_offset;
|
||||
block_2_ctile_map_streamk.get_tile_idx_with_offset(
|
||||
iter_end - 1, tile_idx, iter_offset); // HS
|
||||
iter_offset =
|
||||
__builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1); // HS
|
||||
iter_end - 1, tile_idx, iter_offset);
|
||||
iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1);
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(problem.M,
|
||||
problem.MPadded,
|
||||
@@ -1237,17 +1195,13 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid /*+ splitk_batch_offset.a_k_split_offset*/,
|
||||
a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid /*+ splitk_batch_offset.b_k_split_offset*/,
|
||||
b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
|
||||
// const auto block_work_idx =
|
||||
// block_2_ctile_map.CalculateBottomIndex(make_multi_index(block_idx));
|
||||
auto block_work_idx =
|
||||
block_2_ctile_map_streamk.tile_to_spatial(tile_idx, problem.M, problem.N); // HS
|
||||
block_2_ctile_map_streamk.tile_to_spatial(tile_idx, problem.M, problem.N);
|
||||
|
||||
const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
|
||||
const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
|
||||
@@ -1260,7 +1214,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
__builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
|
||||
|
||||
const index_t k0_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(iter_offset * AK0Number); // HS
|
||||
__builtin_amdgcn_readfirstlane(iter_offset * AK0Number);
|
||||
|
||||
// lds max alignment
|
||||
constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
|
||||
@@ -1298,7 +1252,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
true,
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(k0_block_data_idx_on_grid, m_block_data_idx_on_grid, 0), // HS
|
||||
make_multi_index(k0_block_data_idx_on_grid, m_block_data_idx_on_grid, 0),
|
||||
a_element_op,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
@@ -1361,7 +1315,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
|
||||
num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
|
||||
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
|
||||
KPerBlock); // HS:AK0*KPadded/KPerBlock
|
||||
KPerBlock); // AK0*KPadded/KPerBlock
|
||||
|
||||
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
@@ -1607,7 +1561,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
if(iter_end <= iter_start)
|
||||
break;
|
||||
// make sure next loop LDS is ready for use
|
||||
block_sync_lds(); // HS
|
||||
block_sync_lds();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1627,13 +1581,11 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
const BElementwiseOperation b_element_op{};
|
||||
const CElementwiseOperation c_element_op{};
|
||||
|
||||
Block2CTileMap_streamk block_2_ctile_map_streamk(problem.M,
|
||||
problem.N,
|
||||
AK0Number * problem.KPadded,
|
||||
problem.Grid_size); // HS
|
||||
uint32_t iter_start, iter_end; // HS
|
||||
bool is_sk_block, is_dp_block; //, is_padding_block; //, is_reduction_block; // HS
|
||||
index_t num_k_block_main_loop; // HS
|
||||
Block2CTileMap_streamk block_2_ctile_map_streamk(
|
||||
problem.M, problem.N, AK0Number * problem.KPadded, problem.Grid_size);
|
||||
uint32_t iter_start, iter_end;
|
||||
bool is_sk_block, is_dp_block; //, is_padding_block; //, is_reduction_block;
|
||||
index_t num_k_block_main_loop;
|
||||
|
||||
for(auto block_idx = get_block_1d_id();
|
||||
block_idx < block_2_ctile_map_streamk.get_grid_dims();
|
||||
@@ -1644,21 +1596,20 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
is_dp_block =
|
||||
static_cast<uint32_t>(block_idx) >= block_2_ctile_map_streamk.dp_start_block_idx &&
|
||||
static_cast<uint32_t>(block_idx) <
|
||||
block_2_ctile_map_streamk.reduction_start_block_idx; // HS
|
||||
block_2_ctile_map_streamk.reduction_start_block_idx;
|
||||
|
||||
block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end); // HS
|
||||
num_k_block_main_loop = iter_end - iter_start; // HS
|
||||
block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end);
|
||||
num_k_block_main_loop = iter_end - iter_start;
|
||||
|
||||
{
|
||||
|
||||
uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(
|
||||
block_2_ctile_map_streamk.get_current_iter_length(
|
||||
iter_start, iter_end, num_k_block_main_loop)); // HS
|
||||
uint32_t tile_idx, iter_offset; // HS
|
||||
iter_start, iter_end, num_k_block_main_loop));
|
||||
uint32_t tile_idx, iter_offset;
|
||||
block_2_ctile_map_streamk.get_tile_idx_with_offset(
|
||||
iter_end - 1, tile_idx, iter_offset); // HS
|
||||
iter_offset =
|
||||
__builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1); // HS
|
||||
iter_end - 1, tile_idx, iter_offset);
|
||||
iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1);
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(problem.M,
|
||||
problem.MPadded,
|
||||
@@ -1683,16 +1634,12 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid /*+ splitk_batch_offset.a_k_split_offset*/,
|
||||
a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid /*+ splitk_batch_offset.b_k_split_offset*/,
|
||||
b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
|
||||
// const auto block_work_idx =
|
||||
// block_2_ctile_map.CalculateBottomIndex(make_multi_index(block_idx));
|
||||
auto block_work_idx =
|
||||
block_2_ctile_map_streamk.tile_to_spatial(tile_idx, problem.M, problem.N); // HS
|
||||
block_2_ctile_map_streamk.tile_to_spatial(tile_idx, problem.M, problem.N);
|
||||
|
||||
const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
|
||||
const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
|
||||
@@ -1704,7 +1651,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
const index_t n_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
|
||||
const index_t k0_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(iter_offset * AK0Number); // HS
|
||||
__builtin_amdgcn_readfirstlane(iter_offset * AK0Number);
|
||||
|
||||
// lds max alignment
|
||||
constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
|
||||
@@ -1742,7 +1689,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
true,
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(k0_block_data_idx_on_grid, m_block_data_idx_on_grid, 0), // HS
|
||||
make_multi_index(k0_block_data_idx_on_grid, m_block_data_idx_on_grid, 0),
|
||||
a_element_op,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
@@ -1773,7 +1720,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
true,
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
make_multi_index(k0_block_data_idx_on_grid, n_block_data_idx_on_grid, 0), // HS
|
||||
make_multi_index(k0_block_data_idx_on_grid, n_block_data_idx_on_grid, 0),
|
||||
b_element_op,
|
||||
b_block_desc_bk0_n_bk1,
|
||||
make_multi_index(0, 0, 0),
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -237,306 +237,6 @@ void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpaddin
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
// #if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8))
|
||||
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnkpadding_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_default_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_kpadding_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_mnkpadding_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_default_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_kpadding_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_mnkpadding_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_default_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnpadding_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnkpadding_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_default_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_mnkpadding_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_mnkpadding_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_kpadding_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnpadding_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnkpadding_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_default_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_kpadding_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_default_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_kpadding_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_default_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnpadding_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnkpadding_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_default_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_kpadding_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_default_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
// #endif
|
||||
// #ifdef CK_ENABLE_FP16
|
||||
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_default_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_kpadding_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_mnkpadding_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_mnpadding_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_mnkpadding_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_mnkpadding_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
|
||||
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_mnkpadding_instances(
|
||||
// std::vector<std::unique_ptr<
|
||||
// DeviceGemm_Streamk_V2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough,
|
||||
// PassThrough>>>& instances);
|
||||
// #endif
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
@@ -626,158 +326,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemm_S
|
||||
}
|
||||
}
|
||||
#endif
|
||||
// #if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8))
|
||||
// if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
|
||||
// is_same_v<CDataType, half_t>)
|
||||
// {
|
||||
// if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
// is_same_v<CLayout, Row>)
|
||||
// {
|
||||
// add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_default_instances(op_ptrs);
|
||||
// add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instances(op_ptrs);
|
||||
// add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnpadding_instances(op_ptrs);
|
||||
// add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnkpadding_instances(
|
||||
// op_ptrs);
|
||||
|
||||
// add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_default_instances(op_ptrs);
|
||||
// add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instances(
|
||||
// op_ptrs);
|
||||
// add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_mnkpadding_instances(
|
||||
// op_ptrs);
|
||||
|
||||
// add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instances(op_ptrs);
|
||||
// add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instances(
|
||||
// op_ptrs);
|
||||
// add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_mnkpadding_instances(
|
||||
// op_ptrs);
|
||||
// }
|
||||
// else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
// is_same_v<CLayout, Row>)
|
||||
// {
|
||||
// add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instances(op_ptrs);
|
||||
// add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instances(op_ptrs);
|
||||
// add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instances(op_ptrs);
|
||||
// add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnkpadding_instances(
|
||||
// op_ptrs);
|
||||
|
||||
// add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_default_instances(op_ptrs);
|
||||
// add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_kpadding_instances(
|
||||
// op_ptrs);
|
||||
// add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_mnkpadding_instances(
|
||||
// op_ptrs);
|
||||
|
||||
// add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_default_instances(op_ptrs);
|
||||
// add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_kpadding_instances(
|
||||
// op_ptrs);
|
||||
// add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_mnkpadding_instances(
|
||||
// op_ptrs);
|
||||
// }
|
||||
// }
|
||||
// else if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, half_t> &&
|
||||
// is_same_v<CDataType, half_t>)
|
||||
// {
|
||||
// if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
// is_same_v<CLayout, Row>)
|
||||
// {
|
||||
// add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_default_instances(op_ptrs);
|
||||
// add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instances(op_ptrs);
|
||||
// add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnpadding_instances(op_ptrs);
|
||||
// add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnkpadding_instances(
|
||||
// op_ptrs);
|
||||
|
||||
// add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_default_instances(op_ptrs);
|
||||
// add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_kpadding_instances(
|
||||
// op_ptrs);
|
||||
// add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances(
|
||||
// op_ptrs);
|
||||
|
||||
// add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_default_instances(op_ptrs);
|
||||
// add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instances(
|
||||
// op_ptrs);
|
||||
// add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances(
|
||||
// op_ptrs);
|
||||
// }
|
||||
// else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
// is_same_v<CLayout, Row>)
|
||||
// {
|
||||
// add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instances(op_ptrs);
|
||||
// add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_kpadding_instances(op_ptrs);
|
||||
// add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnpadding_instances(op_ptrs);
|
||||
// add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnkpadding_instances(
|
||||
// op_ptrs);
|
||||
|
||||
// add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_default_instances(op_ptrs);
|
||||
// add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_kpadding_instances(
|
||||
// op_ptrs);
|
||||
// add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances(
|
||||
// op_ptrs);
|
||||
|
||||
// add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_default_instances(op_ptrs);
|
||||
// add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_kpadding_instances(
|
||||
// op_ptrs);
|
||||
// add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances(
|
||||
// op_ptrs);
|
||||
// }
|
||||
// }
|
||||
// #endif
|
||||
// #ifdef CK_ENABLE_FP16
|
||||
// if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, bhalf_t> &&
|
||||
// is_same_v<CDataType, bhalf_t>)
|
||||
// {
|
||||
// if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
// is_same_v<CLayout, Row>)
|
||||
// {
|
||||
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instances(
|
||||
// op_ptrs);
|
||||
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instances(
|
||||
// op_ptrs);
|
||||
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instances(
|
||||
// op_ptrs);
|
||||
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instances(
|
||||
// op_ptrs);
|
||||
|
||||
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_default_instances(
|
||||
// op_ptrs);
|
||||
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_kpadding_instances(
|
||||
// op_ptrs);
|
||||
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_mnkpadding_instances(
|
||||
// op_ptrs);
|
||||
|
||||
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instances(
|
||||
// op_ptrs);
|
||||
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instances(
|
||||
// op_ptrs);
|
||||
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instances(
|
||||
// op_ptrs);
|
||||
// }
|
||||
// else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
// is_same_v<CLayout, Row>)
|
||||
// {
|
||||
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instances(
|
||||
// op_ptrs);
|
||||
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instances(
|
||||
// op_ptrs);
|
||||
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_mnpadding_instances(
|
||||
// op_ptrs);
|
||||
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_mnkpadding_instances(
|
||||
// op_ptrs);
|
||||
|
||||
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instances(
|
||||
// op_ptrs);
|
||||
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instances(
|
||||
// op_ptrs);
|
||||
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_mnkpadding_instances(
|
||||
// op_ptrs);
|
||||
|
||||
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instances(
|
||||
// op_ptrs);
|
||||
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instances(
|
||||
// op_ptrs);
|
||||
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_mnkpadding_instances(
|
||||
// op_ptrs);
|
||||
// }
|
||||
// }
|
||||
// #endif
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -21,70 +21,6 @@ list(APPEND GEMM_UNIVERSAL_STREAMK_INSTANCES
|
||||
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp
|
||||
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp
|
||||
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp
|
||||
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp
|
||||
|
||||
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instance.cpp
|
||||
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnkpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_default_instance.cpp
|
||||
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_default_instance.cpp
|
||||
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_default_instance.cpp
|
||||
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnkpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_default_instance.cpp
|
||||
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instance.cpp
|
||||
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp
|
||||
|
||||
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instance.cpp
|
||||
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp
|
||||
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp
|
||||
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_default_instance.cpp
|
||||
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp
|
||||
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp
|
||||
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp
|
||||
|
||||
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp
|
||||
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_default_instance.cpp
|
||||
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_kpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instance.cpp
|
||||
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instance.cpp
|
||||
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_mnpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_mnkpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instance.cpp
|
||||
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instance.cpp
|
||||
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp
|
||||
)
|
||||
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp)
|
||||
|
||||
add_instance_library(device_gemm_universal_streamk_instance ${GEMM_UNIVERSAL_STREAMK_INSTANCES})
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp"
|
||||
|
||||
|
||||
@@ -21,70 +21,6 @@ list(APPEND GEMM_UNIVERSAL_STREAMK_INSTANCES
|
||||
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp
|
||||
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp
|
||||
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp
|
||||
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp
|
||||
|
||||
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instance.cpp
|
||||
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnkpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_default_instance.cpp
|
||||
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_default_instance.cpp
|
||||
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_default_instance.cpp
|
||||
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnkpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_default_instance.cpp
|
||||
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instance.cpp
|
||||
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp
|
||||
|
||||
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instance.cpp
|
||||
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp
|
||||
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp
|
||||
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_default_instance.cpp
|
||||
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp
|
||||
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp
|
||||
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp
|
||||
|
||||
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp
|
||||
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_default_instance.cpp
|
||||
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_kpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instance.cpp
|
||||
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instance.cpp
|
||||
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_mnpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_mnkpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instance.cpp
|
||||
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instance.cpp
|
||||
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
|
||||
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp
|
||||
)
|
||||
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp)
|
||||
|
||||
add_instance_library(device_gemm_universal_streamk_instance ${GEMM_UNIVERSAL_STREAMK_INSTANCES})
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -153,7 +153,9 @@ bool profile_gemm_universal_streamk_impl(int do_verification,
|
||||
for(auto& op_ptr : op_ptrs)
|
||||
{
|
||||
std::vector<int> grid_size_list = {38, 76, 114, 152, 190, 228, 266, 304, 342, 380};
|
||||
std::vector<int> streamk_sel_list = {0, 1, 2, 3, 4};
|
||||
std::vector<int> streamk_sel_list = {
|
||||
0, 1, 2, 3, 4}; // 0: Data Parallel (DP) mode (Stream-K OFF), 1: 1-tile Stream-K + DP,
|
||||
// 2: 2-tile Stream-K + DP
|
||||
|
||||
if(Grid_size == -1)
|
||||
{
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
@@ -145,30 +145,6 @@ int profile_gemm_universal_streamk(int argc, char* argv[])
|
||||
{
|
||||
return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{});
|
||||
}
|
||||
// else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
// {
|
||||
// return profile(F16{}, F8{}, F32{}, F16{}, Row{}, Row{}, Row{});
|
||||
// }
|
||||
// else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_NK_MN)
|
||||
// {
|
||||
// return profile(F16{}, F8{}, F32{}, F16{}, Row{}, Col{}, Row{});
|
||||
// }
|
||||
// else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
// {
|
||||
// return profile(F8{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{});
|
||||
// }
|
||||
// else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
|
||||
// {
|
||||
// return profile(F8{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{});
|
||||
// }
|
||||
// else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
// {
|
||||
// return profile(BF16{}, BF16{}, F32{}, BF16{}, Row{}, Row{}, Row{});
|
||||
// }
|
||||
// else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
|
||||
// {
|
||||
// return profile(BF16{}, BF16{}, F32{}, BF16{}, Row{}, Col{}, Row{});
|
||||
// }
|
||||
else
|
||||
{
|
||||
std::cout << "this data_type & layout is not implemented" << std::endl;
|
||||
|
||||
Reference in New Issue
Block a user