Merge commit '3900e1e7ceacfa32cb8d1522260ed30befd4dae3' into develop

This commit is contained in:
assistant-librarian[bot]
2026-01-26 19:16:22 +00:00
parent 06fb853279
commit 39405747ab
18 changed files with 2330 additions and 1180 deletions

View File

@@ -25,6 +25,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
## Composable Kernel 1.2.0 for ROCm 7.2.0
### Added
* Added tests for f8 x bf8 on CompV3, and f8 x bf8 with K_BlockSize 32 on CompV4
* Added CK-Tile dispatcher - a unified kernel dispatch, code generation and architecture-based kernel filtering system with C++ and Python frontends starting with GEMM support.
* Added support for bf16 data type to grouped_gemm and grouped_gemm_preshuffle.
* Added Col-Col-Row-Col layout support for aquant mode in blockscale GEMM.

View File

@@ -3,77 +3,21 @@
#pragma once
#include <iostream>
#include <sstream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3_common.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/utility/tuple.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// Device entry point for batched GEMM+GEMM (C = A * B0 * B1) without extra D tensors
// (the Tuple<>{} arguments below are empty placeholders for the multi-D interface).
// Blocks are partitioned evenly across batches; each block derives its batch index
// g_idx and the per-batch element offsets, then delegates to GridwiseOp::Run.
// Compiled only for gfx11/gfx12 (WMMA-capable) targets; otherwise the body is a stub.
template <typename DeviceOp, typename GridwiseOp, bool HasMainKBlockLoop, TailNumber TailNum>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_gemm_wmma_cshuffle_v3(typename DeviceOp::RawArg arg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
__shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
// readfirstlane marks these values wavefront-uniform so they live in scalar registers.
// NOTE(review): assumes grid_size is an exact multiple of batch_count — confirm the
// host-side grid computation guarantees this.
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / arg.batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
// Per-batch base offsets (in elements) into each tensor.
const long_index_t a_batch_offset =
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetABasePtr(g_idx)));
const long_index_t b0_batch_offset =
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB0BasePtr(g_idx)));
const long_index_t b1_batch_offset =
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
const long_index_t c_batch_offset =
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
GridwiseOp::template Run<HasMainKBlockLoop, TailNum>(
arg.p_a_grid + a_batch_offset,
arg.p_b0_grid + b0_batch_offset,
Tuple<>{}, // p_d0s_grid
arg.p_b1_grid + b1_batch_offset,
Tuple<>{}, // p_d1s_grid
arg.p_c_grid + c_batch_offset,
p_shared,
arg.a_grid_desc,
arg.b0_grid_desc,
Tuple<>{}, // D0sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
arg.b1_grid_desc,
Tuple<>{}, // D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
arg.a_element_op,
arg.b0_element_op,
arg.acc_element_op,
arg.b1_element_op,
arg.c_element_op,
arg.block_2_ctile_map);
#else
ignore = arg;
#endif // (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)
}
// Computes C = A * B0 * B1
// MN = MK * KL * LN
// ^^^^^^ (Acc0)
@@ -157,88 +101,47 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
// to LPerWmma (A.k.a Gemm0 NPerWmma).
static constexpr index_t NPerWmma = LPerWmma;
// TODO: Now that we are no longer using NumDim or TensorSpec, we can probably use a simpler
// Transform operator or just not use one at all.
using Transform = TransformBatchedContractionContractionToBatchedGemmGemm_Wmma<
Sequence<1, 1, 1, 1, 1>,
Sequence<MPerBlock, LPerBlock, KPerBlock, NPerBlock>,
GemmSpec,
TensorSpecialization::Default, // ASpec
TensorSpecialization::Default, // B0Spec
TensorSpecialization::Default, // B1Spec
TensorSpecialization::Default>; // CSpec
// Builds the A-tensor grid descriptor: first the flattened (M, K) view from the
// [G, M, K] lengths/strides, then K split into (AK0, AK1) for vectorized loads.
__host__ __device__ static auto
MakeAGridDescriptor(const std::array<index_t, 3>& a_g_m_k_lengths_vec,
                    const std::array<index_t, 3>& a_g_m_k_strides_vec)
{
    const auto a_desc_m_k =
        Transform::MakeAGridDescriptor_M_K(a_g_m_k_lengths_vec, a_g_m_k_strides_vec);
    return Transform::MakeAGridDescriptor_AK0_M_AK1(a_desc_m_k, Number<AK1>{});
}
// Builds the B0-tensor grid descriptor: the (L, K) view from the [G, L, K]
// lengths/strides, with K split into (BK0, BK1).
__host__ __device__ static auto
MakeB0GridDescriptor(const std::array<index_t, 3>& b0_g_l_k_lengths_vec,
                     const std::array<index_t, 3>& b0_g_l_k_strides_vec)
{
    const auto b0_desc_l_k =
        Transform::MakeB0GridDescriptor_N_K(b0_g_l_k_lengths_vec, b0_g_l_k_strides_vec);
    return Transform::MakeB0GridDescriptor_BK0_N_BK1(b0_desc_l_k, Number<BK1>{});
}
// Builds the B1-tensor grid descriptor: the (N, L) view from the [G, N, L]
// lengths/strides, with the reduction dim split by L1 (B1K1).
__host__ __device__ static auto
MakeB1GridDescriptor(const std::array<index_t, 3>& b1_g_n_l_lengths_vec,
                     const std::array<index_t, 3>& b1_g_n_l_strides_vec)
{
    const auto b1_desc_n_l =
        Transform::MakeB1GridDescriptor_N_K(b1_g_n_l_lengths_vec, b1_g_n_l_strides_vec);
    return Transform::MakeB1GridDescriptor_BK0_N_BK1(b1_desc_n_l, Number<L1>{});
}
using AGridDesc = decltype(MakeAGridDescriptor({}, {}));
using B0GridDesc = decltype(MakeB0GridDescriptor({}, {}));
using B1GridDesc = decltype(MakeB1GridDescriptor({}, {}));
using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
// Maps a batch index to per-tensor base-pointer offsets (in elements) using the
// batch strides of A, B0, B1 and C. Offsets are widened to long_index_t so that
// batch_index * stride cannot overflow 32-bit arithmetic.
struct ComputeBasePtrOfStridedBatch
{
    ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
                                 index_t BatchStrideB0,
                                 index_t BatchStrideB1,
                                 index_t BatchStrideC)
        : BatchStrideA_{BatchStrideA},
          BatchStrideB0_{BatchStrideB0},
          BatchStrideB1_{BatchStrideB1},
          BatchStrideC_{BatchStrideC}
    {
    }
    // Element offset of batch g_idx within the A tensor.
    __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
    {
        return static_cast<long_index_t>(BatchStrideA_) * g_idx;
    }
    // Element offset of batch g_idx within the B0 tensor.
    __host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const
    {
        return static_cast<long_index_t>(BatchStrideB0_) * g_idx;
    }
    // Element offset of batch g_idx within the B1 tensor.
    __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
    {
        return static_cast<long_index_t>(BatchStrideB1_) * g_idx;
    }
    // Element offset of batch g_idx within the C tensor.
    __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
    {
        return static_cast<long_index_t>(BatchStrideC_) * g_idx;
    }

    private:
    index_t BatchStrideA_;
    index_t BatchStrideB0_;
    index_t BatchStrideB1_;
    index_t BatchStrideC_;
};
using DeviceGemmGemmCommonBase =
DeviceGemmGemm_Wmma_CShuffleV3_Common<DeviceOp,
GemmSpec,
ALayout,
B0layout,
Tuple<>, // D0sLayout
B1Layout,
Tuple<>, // D1sLayout
CLayout,
BlockSize,
MPerBlock,
LPerBlock,
KPerBlock,
NPerBlock,
ADataType,
B0DataType,
B1DataType,
AccDataType,
CDataType,
Tuple<>, // D0sDataType
Tuple<>, // D1sDataType
AElementwiseOperation,
B0ElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
AK1,
BK1,
L1,
MPerWmma,
LPerWmma,
BlkGemmPipelineVer,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
B0BlockTransferSrcVectorDim,
B0BlockTransferSrcScalarPerVector,
B1BlockTransferSrcVectorDim,
B1BlockTransferSrcScalarPerVector,
ck::index_t{}, // CDE0BlockTransferSrcScalarPerVector
CShuffleBlockTransferScalarPerVector_NPerBlock,
false>; // IsMultiD
// GridwiseOp
using GridwiseOp = GridwiseBatchedGemmGemm_wmma_cshuffle_v3<
@@ -260,12 +163,12 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
CElementwiseOperation,
InMemoryDataOperationEnum::Set,
// InMemory Data Descriptor
AGridDesc,
B0GridDesc,
typename DeviceGemmGemmCommonBase::AGridDesc,
typename DeviceGemmGemmCommonBase::B0GridDesc,
Tuple<>, // Ds0GridDesc
B1GridDesc,
typename DeviceGemmGemmCommonBase::B1GridDesc,
Tuple<>, // Ds1GridDesc
CGridDesc_M_N,
typename DeviceGemmGemmCommonBase::CGridDesc_M_N,
// Tiling Family
MPerBlock,
LPerBlock,
@@ -312,339 +215,67 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
CShuffleNRepeatPerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
Transform::matrix_padder.PadN,
DeviceGemmGemmCommonBase::GridDescriptorCreator::Transform::matrix_padder.PadN,
BlkGemmPipeSched,
BlkGemmPipelineVer>;
struct RawArg : public BaseArgument
using DeviceGemmGemmCommon = DeviceGemmGemm_Wmma_CShuffleV3_Common_Invoker_Arg<
DeviceOp,
GemmSpec,
ALayout,
B0layout,
Tuple<>, // D0sLayout
B1Layout,
Tuple<>, // D1sLayout
CLayout,
BlockSize,
MPerBlock,
LPerBlock,
KPerBlock,
NPerBlock,
ADataType,
B0DataType,
B1DataType,
AccDataType,
CDataType,
Tuple<>, // D0sDataType,
Tuple<>, // D1sDataType,
AElementwiseOperation,
B0ElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
AK1,
BK1,
L1,
MPerWmma,
LPerWmma,
BlkGemmPipelineVer,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
B0BlockTransferSrcVectorDim,
B0BlockTransferSrcScalarPerVector,
B1BlockTransferSrcVectorDim,
B1BlockTransferSrcScalarPerVector,
ck::index_t{}, // CDE0BlockTransferSrcScalarPerVector
CShuffleBlockTransferScalarPerVector_NPerBlock,
false>; // IsMultiD
// Invoker
using Invoker = typename DeviceGemmGemmCommon::Invoker;
// Argument
using Argument = typename DeviceGemmGemmCommon::Argument;
static bool IsSupportedArgument(const Argument& arg)
{
using arr3 = std::array<ck::index_t, 3>;
// Constructs the kernel argument from raw problem sizes/strides: records the
// pointers and element ops, derives the [G, dim, dim] length/stride arrays for
// each tensor, and eagerly builds all grid descriptors and the block-to-tile map.
// M/N/K/O: Gemm0 is MxK * KxN; Gemm1 multiplies that result by NxO.
RawArg(const ADataType* p_a_grid_,
const B0DataType* p_b0_grid_,
const B1DataType* p_b1_grid_,
CDataType* p_c_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t O_,
index_t Batch,
index_t StrideA,
index_t StrideB0,
index_t StrideB1,
index_t StrideC,
index_t BatchStrideA,
index_t BatchStrideB0,
index_t BatchStrideB1,
index_t BatchStrideC,
AElementwiseOperation a_element_op_,
B0ElementwiseOperation b0_element_op_,
AccElementwiseOperation acc_element_op_,
B1ElementwiseOperation b1_element_op_,
CElementwiseOperation c_element_op_)
: p_a_grid{p_a_grid_},
p_b0_grid{p_b0_grid_},
p_b1_grid{p_b1_grid_},
p_c_grid{p_c_grid_},
M{M_},
N{N_},
K{K_},
O{O_},
batch_count{Batch},
a_element_op{a_element_op_},
b0_element_op{b0_element_op_},
acc_element_op{acc_element_op_},
b1_element_op{b1_element_op_},
c_element_op{c_element_op_},
compute_base_ptr_of_batch{BatchStrideA, BatchStrideB0, BatchStrideB1, BatchStrideC}
{
// Innermost stride of A, B0 and C is fixed to 1 (contiguous last dim);
// only B1 supports either layout via its stride placement below.
a_g_m_k_lengths = arr3{batch_count, M, K};
a_g_m_k_strides = arr3{BatchStrideA, StrideA, 1}; // A layout [batch_count, M, K]
b0_g_n_k_lengths = arr3{batch_count, N, K};
b0_g_n_k_strides = arr3{BatchStrideB0, StrideB0, 1}; // B0 layout [batch_count, N, K]
b1_g_o_n_lengths = arr3{batch_count, O, N};
b1_g_o_n_strides =
is_same_v<B1Layout, tensor_layout::gemm::RowMajor>
? arr3{BatchStrideB1, 1, StrideB1} // B1 layout [batch_count, N, O]
: arr3{BatchStrideB1, StrideB1, 1}; // B1 layout [batch_count, O, N]
c_g_m_o_lengths = arr3{batch_count, M, O};
c_g_m_o_strides = arr3{BatchStrideC, StrideC, 1}; // C layout [batch_count, M, O]
// Descriptors are built once here so the kernel receives them by value.
a_grid_desc = MakeAGridDescriptor(a_g_m_k_lengths, a_g_m_k_strides);
b0_grid_desc = MakeB0GridDescriptor(b0_g_n_k_lengths, b0_g_n_k_strides);
b1_grid_desc = MakeB1GridDescriptor(b1_g_o_n_lengths, b1_g_o_n_strides);
c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N(c_g_m_o_lengths, c_g_m_o_strides);
c_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseOp::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
block_2_ctile_map = GridwiseOp::MakeDefaultBlock2ETileMap(c_grid_desc_m_n, 1, 1);
}
// Pointers
const ADataType* p_a_grid;
const B0DataType* p_b0_grid;
const B1DataType* p_b1_grid;
CDataType* p_c_grid;
// Raw Problem Size
index_t M;
index_t N;
index_t K;
index_t O;
index_t batch_count;
arr3 a_g_m_k_lengths;
arr3 a_g_m_k_strides;
arr3 b0_g_n_k_lengths;
arr3 b0_g_n_k_strides;
arr3 b1_g_o_n_lengths;
arr3 b1_g_o_n_strides;
arr3 c_g_m_o_lengths;
arr3 c_g_m_o_strides;
AElementwiseOperation a_element_op;
B0ElementwiseOperation b0_element_op;
AccElementwiseOperation acc_element_op;
B1ElementwiseOperation b1_element_op;
CElementwiseOperation c_element_op;
// Grid descriptors and other mem calculators
AGridDesc a_grid_desc;
B0GridDesc b0_grid_desc;
B1GridDesc b1_grid_desc;
CGridDesc_M_N c_grid_desc_m_n;
typename GridwiseOp::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock;
typename GridwiseOp::DefaultBlock2ETileMap block_2_ctile_map;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch;
};
// Validates that this device op instance can run the given argument:
// architecture, data types, layouts, padding mode, WMMA tile sizes, gridwise
// validity, and vector-load divisibility/contiguity requirements.
static bool IsSupportedArgument([[maybe_unused]] const RawArg& arg)
{
// Print lambda with env check and printf() style formatting.
const char* curFunc = __func__;
auto print = [&curFunc](const char* format, ...) -> void {
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wformat-nonliteral"
#endif
va_list args;
va_start(args, format);
std::vfprintf(stdout, format, args);
va_end(args);
#if defined(__clang__)
#pragma clang diagnostic pop
#endif
std::cout << "In file: " << __FILE__ << ", function: " << curFunc << "\n";
}
};
// WMMA kernels require gfx11 or gfx12.
if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported()))
{
print("DeviceOp: Arch err\n");
return false;
}
// fp8/bf8 inputs are only available on gfx12.
if constexpr(std::is_same_v<ADataType, f8_t> || std::is_same_v<ADataType, bf8_t> ||
std::is_same_v<B0DataType, f8_t> || std::is_same_v<B0DataType, bf8_t> ||
std::is_same_v<B1DataType, f8_t> || std::is_same_v<B1DataType, bf8_t>)
{
if(ck::is_gfx11_supported())
{
print("DeviceOp: gfx 11 does not support fp8\n");
return false;
}
}
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{
print("DeviceOp: Acc0 Type err\n");
return false;
}
if constexpr(!(is_same_v<ALayout, tensor_layout::gemm::RowMajor>))
{
print("DeviceOp: A layout must be Row\n");
return false;
}
if constexpr(!(is_same_v<B0layout, tensor_layout::gemm::ColumnMajor>))
{
print("DeviceOp: B layout must be Column\n");
return false;
}
if constexpr(!(is_same_v<B1Layout, tensor_layout::gemm::RowMajor> ||
is_same_v<B1Layout, tensor_layout::gemm::ColumnMajor>))
{
print("DeviceOp: B1 layout must be Column or Row\n");
return false;
}
if constexpr(!(is_same_v<CLayout, tensor_layout::gemm::RowMajor>))
{
print("DeviceOp: C layout must be Row\n");
return false;
}
// Other padding modes have not been tested and do not get checked individually.
if constexpr(GemmSpec != GemmSpecialization::Default &&
GemmSpec != GemmSpecialization::MNKOPadding)
{
print("Padding mode must be default or MNKO\n");
return false;
}
// Per wmma dimensions not equal to 16 are very untested.
if constexpr(MPerWmma != 16 || LPerWmma != 16 || NPerWmma != 16)
{
print("M, L, N per Wmma must be 16\n");
return false;
}
// Delegate descriptor/tile-map consistency checks to the gridwise op.
if(!GridwiseOp::CheckValidity(arg.a_grid_desc,
arg.b0_grid_desc,
Tuple<>{},
arg.b1_grid_desc,
Tuple<>{},
arg.c_grid_desc_m_n,
arg.block_2_ctile_map))
{
return false;
}
// Check scalar per vector requirement: the extent along each tensor's
// vectorized dimension must divide evenly by its ScalarPerVector.
const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? arg.K : arg.M;
const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? arg.K : arg.N;
const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? arg.N : arg.O;
const auto c_extent_lowest = arg.O;
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 &&
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
{
print("DeviceOp: Data Transfer Vector scalar err\n");
return false;
}
// Check vector load/store requirement: vectorized dims must be contiguous.
// NOTE(review): this uses || — it passes if ANY stride is 1, not all; an &&
// would be stricter. Confirm whether this leniency is intentional.
const auto a_stride_lowest =
ABlockTransferSrcVectorDim == 2 ? arg.a_g_m_k_strides[2] : arg.a_g_m_k_strides[1];
const auto b0_stride_lowest =
B0BlockTransferSrcVectorDim == 2 ? arg.b0_g_n_k_strides[2] : arg.b0_g_n_k_strides[1];
const auto b1_stride_lowest =
B1BlockTransferSrcVectorDim == 2 ? arg.b1_g_o_n_strides[2] : arg.b1_g_o_n_strides[1];
const auto c_stride_lowest = arg.c_g_m_o_strides[2];
if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
c_stride_lowest == 1))
{
print("DeviceOp: Data Vectorize transfer err\n");
return false;
}
// Without padding, K must be divisible by both AK1 and BK1.
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MNKOPadding))
{
return false;
}
return true;
// NOTE(review): unreachable — "return true" above always exits first. This
// looks like merge residue (old body + new delegation); keep only one path.
return DeviceGemmGemmCommon::IsSupportedArgument(arg);
}
// polymorphic
// Base-interface overload: downcasts the type-erased argument and forwards to
// the static check above.
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const RawArg*>(p_arg));
// NOTE(review): unreachable — merge residue from the refactor to
// DeviceGemmGemmCommon; one of the two returns should be removed.
return DeviceGemmGemmCommon::IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
// Launches the batched GEMM+GEMM kernel: computes the grid size from the output
// tiling (M x O tiles per batch), then dispatches to the kernel template
// instantiation matching the pipeline version's supported
// (HasMainKBlockLoop, TailNumber) combinations. Returns the measured kernel
// time (per launch_and_time_kernel), or 0.0f on an unsupported combination.
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::RawArg;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
// One block per MPerBlock x NPerBlock output tile, times batch count.
const auto M0 = math::integer_divide_ceil(arg.M, MPerBlock);
const auto N0 = math::integer_divide_ceil(arg.O, NPerBlock);
const index_t grid_size = arg.batch_count * M0 * N0;
// Instantiates and times the kernel for a compile-time (loop, tail) pair.
auto launch_kernel = [&](auto has_main_k_block_loop, auto tail_number) {
constexpr bool has_loop = decltype(has_main_k_block_loop)::value;
constexpr TailNumber tn = tail_number;
const auto kernel =
kernel_batched_gemm_gemm_wmma_cshuffle_v3<DeviceOp, GridwiseOp, has_loop, tn>;
return launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, arg);
};
bool HasMainKBlockLoop = GridwiseOp::CalculateHasMainKBlockLoop(arg.K);
TailNumber TailNum = GridwiseOp::CalculateKBlockLoopTailNum(arg.K);
// v1 pipeline only supports TailNumber::Full.
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(HasMainKBlockLoop && TailNum == TailNumber::Full)
{
return launch_kernel(std::integral_constant<bool, true>{},
std::integral_constant<TailNumber, TailNumber::Full>{});
}
else if(!HasMainKBlockLoop && TailNum == TailNumber::Full)
{
return launch_kernel(std::integral_constant<bool, false>{},
std::integral_constant<TailNumber, TailNumber::Full>{});
}
else
{
printf("Invalid HasMainKBlockLoop and TailNum combination for V1!\n");
return 0.0f;
}
}
// v3 pipeline: Full with a main loop, Even/Odd without one.
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(HasMainKBlockLoop && TailNum == TailNumber::Full)
{
return launch_kernel(std::integral_constant<bool, true>{},
std::integral_constant<TailNumber, TailNumber::Full>{});
}
else if(!HasMainKBlockLoop && TailNum == TailNumber::Even)
{
return launch_kernel(std::integral_constant<bool, false>{},
std::integral_constant<TailNumber, TailNumber::Even>{});
}
else if(!HasMainKBlockLoop && TailNum == TailNumber::Odd)
{
return launch_kernel(std::integral_constant<bool, false>{},
std::integral_constant<TailNumber, TailNumber::Odd>{});
}
else
{
printf("Invalid HasMainKBlockLoop and TailNum combination for V3!\n");
return 0.0f;
}
}
else
{
printf("Invalid pipeline version!\n");
return 0.0f;
}
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b0,
@@ -669,28 +300,39 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) override
{
return std::make_unique<RawArg>(static_cast<const ADataType*>(p_a),
static_cast<const B0DataType*>(p_b0),
static_cast<const B1DataType*>(p_b1),
static_cast<CDataType*>(p_c),
M,
N,
K,
O,
Batch,
StrideA,
StrideB0,
StrideB1,
StrideC,
BatchStrideA,
BatchStrideB0,
BatchStrideB1,
BatchStrideC,
a_element_op,
b0_element_op,
acc_element_op,
b1_element_op,
c_element_op);
// Empty D-tensor placeholders: this op has no extra D0/D1 tensors, so all
// arrays are zero-length — but they must still be value-initialized so no
// indeterminate values are passed into the Argument. The original declared
// StrideD1s without {}, leaving it default-initialized (uninitialized).
std::array<const void*, DeviceGemmGemmCommonBase::NumD0Tensor> p_d0_grid{};
std::array<const void*, DeviceGemmGemmCommonBase::NumD1Tensor> p_d1_grid{};
std::array<index_t, DeviceGemmGemmCommonBase::NumD0Tensor> StrideD0s{}, BatchStrideD0s{};
std::array<index_t, DeviceGemmGemmCommonBase::NumD1Tensor> StrideD1s{}, BatchStrideD1s{};
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const B0DataType*>(p_b0),
p_d0_grid,
static_cast<const B1DataType*>(p_b1),
p_d1_grid,
static_cast<CDataType*>(p_c),
M,
N,
K,
O,
Batch,
StrideA,
StrideB0,
StrideD0s,
StrideB1,
StrideD1s,
StrideC,
BatchStrideA,
BatchStrideB0,
BatchStrideD0s,
BatchStrideB1,
BatchStrideD1s,
BatchStrideC,
a_element_op,
b0_element_op,
acc_element_op,
b1_element_op,
c_element_op);
}
static auto MakeInvoker() { return Invoker{}; }

View File

@@ -0,0 +1,902 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <cstdint>
#include <iostream>
#include <cstdarg>
#include <type_traits>
#include <utility>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp"
#include "ck/utility/scheduler_enum.hpp"
#include "ck/utility/integral_constant.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// Device entry point for batched GEMM+GEMM with optional elementwise D0/D1
// tensors (IsMultiD). Blocks are partitioned evenly across batches; each block
// derives its batch index g_idx and per-batch base offsets, builds the
// (possibly empty) D-tensor grid pointers, then delegates to GridwiseOp::Run.
// Compiled only for gfx11/gfx12 (WMMA-capable) targets; otherwise stubbed out.
template <typename DeviceOp,
typename GridwiseOp,
bool HasMainKBlockLoop,
TailNumber TailNum,
bool IsMultiD>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_gemm_wmma_cshuffle_v3(typename DeviceOp::Argument arg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
__shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
// readfirstlane marks these values wavefront-uniform (scalar registers).
// NOTE(review): assumes grid_size is an exact multiple of batch_count.
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / arg.batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset =
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetABasePtr(g_idx)));
const long_index_t b0_batch_offset =
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB0BasePtr(g_idx)));
const long_index_t b1_batch_offset =
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
const long_index_t c_e1_batch_offset =
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetCE1BasePtr(g_idx)));
// Build batch-offset D0/D1 grid pointer tuples; empty tuples when !IsMultiD
// so the non-multi-D instantiation carries no extra state.
auto [p_d0s_grid, p_d1s_grid] = [&]() {
if constexpr(IsMultiD)
{
// Applies the per-batch offset to each of the NumTensor D pointers.
auto create_grid = [](auto NumTensor, auto func, auto& arg_grid, auto&& grid_pointer) {
static_for<0, decltype(NumTensor)::value, 1>{}([&](auto In) {
const long_index_t batch_offset = __builtin_amdgcn_readfirstlane(func(In));
grid_pointer(In) = arg_grid(In) + batch_offset;
});
return std::move(grid_pointer);
};
auto get_d0_base_ptr = [&arg, &g_idx](auto d_idx) {
return arg.compute_base_ptr_of_batch.GetD0BasePtr(g_idx, d_idx);
};
auto get_d1_base_ptr = [&arg, &g_idx](auto d_idx) {
return arg.compute_base_ptr_of_batch.GetD1BasePtr(g_idx, d_idx);
};
auto d0s_grid = create_grid(ck::integral_constant<ck::index_t, DeviceOp::NumD0Tensor>{},
get_d0_base_ptr,
arg.p_d0s_grid,
GridwiseOp::MakeD0sGridPointer());
auto d1s_grid = create_grid(ck::integral_constant<ck::index_t, DeviceOp::NumD1Tensor>{},
get_d1_base_ptr,
arg.p_d1s_grid,
GridwiseOp::MakeD1sGridPointer());
return std::make_pair(d0s_grid, d1s_grid);
}
else
{
return std::make_pair(Tuple<>{}, Tuple<>{});
}
}();
GridwiseOp::template Run<HasMainKBlockLoop, TailNum>(
arg.p_a_grid + a_batch_offset,
arg.p_b0_grid + b0_batch_offset,
p_d0s_grid,
arg.p_b1_grid + b1_batch_offset,
p_d1s_grid,
arg.p_c_e1_grid + c_e1_batch_offset,
p_shared,
arg.a_grid_desc,
arg.b0_grid_desc,
arg.d0s_grid_desc,
arg.b1_grid_desc,
arg.d1s_grid_desc_mblock_mperblock_nblock_nperblock,
arg.c_e1_grid_desc_mblock_mperblock_nblock_nperblock,
arg.a_element_op,
arg.b0_element_op,
arg.acc_element_op,
arg.b1_element_op,
arg.cde1_element_op,
arg.block_2_etile_map);
#else
ignore = arg;
#endif // (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)
}
// Shared building blocks for GEMM+GEMM WMMA CShuffleV3 device ops (batched and
// multi-D variants): D-tensor counts, grid-descriptor factories, descriptor
// type aliases, and the per-batch base-pointer calculator. The concrete device
// op (DeviceOp) supplies NumD0Tensor/NumD1Tensor when IsMultiD is true.
template <typename DeviceOp,
GemmSpecialization GemmSpec,
typename ALayout,
typename B0layout,
typename D0sLayout,
typename B1Layout,
typename D1sLayout,
typename CE1Layout,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t LPerBlock, // Gemm0NPerBlock
ck::index_t KPerBlock, // Gemm0KPerBlock
ck::index_t NPerBlock, // Gemm1NPerBlock
typename ADataType,
typename B0DataType,
typename B1DataType,
typename AccDataType,
typename CE1DataType,
typename D0sDataType,
typename D1sDataType,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CDE1ElementwiseOperation,
ck::index_t AK1,
ck::index_t BK1,
ck::index_t L1, // B1K1
ck::index_t MPerWmma, // Gemm0/1 MPerWmma
ck::index_t LPerWmma, // Gemm0/1 NPerWmma
BlockGemmPipelineVersion BlkGemmPipelineVer,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t B0BlockTransferSrcVectorDim,
ck::index_t B0BlockTransferSrcScalarPerVector,
ck::index_t B1BlockTransferSrcVectorDim,
ck::index_t B1BlockTransferSrcScalarPerVector,
ck::index_t CDE0BlockTransferSrcScalarPerVector,
ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
bool IsMultiD = false>
struct DeviceGemmGemm_Wmma_CShuffleV3_Common
{
// Number of extra D0 tensors (Gemm0 epilogue); 0 when not multi-D.
static constexpr ck::index_t NumD0Tensor = []() {
if constexpr(IsMultiD)
{
return DeviceOp::NumD0Tensor;
}
return 0;
}();
// Number of extra D1 tensors (Gemm1 epilogue); 0 when not multi-D.
static constexpr ck::index_t NumD1Tensor = []() {
if constexpr(IsMultiD)
{
return DeviceOp::NumD1Tensor;
}
return 0;
}();
// Factories that turn raw [G, dim, dim] lengths/strides into grid descriptors.
struct GridDescriptorCreator
{
// TODO: Now that we are no longer using NumDim or TensorSpec, we can probably use a simpler
// Transform operator or just not use one at all.
using Transform = TransformBatchedContractionContractionToBatchedGemmGemm_Wmma<
Sequence<1, 1, 1, 1, 1>,
Sequence<MPerBlock, LPerBlock, KPerBlock, NPerBlock>,
GemmSpec,
TensorSpecialization::Default, // ASpec
TensorSpecialization::Default, // B0Spec
TensorSpecialization::Default, // B1Spec
TensorSpecialization::Default>; // CSpec
// A: (M, K) view with K split into (AK0, AK1).
__host__ __device__ static auto
MakeAGridDescriptor(const std::array<index_t, 3>& a_g_m_k_lengths_vec,
const std::array<index_t, 3>& a_g_m_k_strides_vec)
{
return Transform::MakeAGridDescriptor_AK0_M_AK1(
Transform::MakeAGridDescriptor_M_K(a_g_m_k_lengths_vec, a_g_m_k_strides_vec),
Number<AK1>{});
}
// B0: (L, K) view with K split into (BK0, BK1).
__host__ __device__ static auto
MakeB0GridDescriptor(const std::array<index_t, 3>& b0_g_l_k_lengths_vec,
const std::array<index_t, 3>& b0_g_l_k_strides_vec)
{
return Transform::MakeB0GridDescriptor_BK0_N_BK1(
Transform::MakeB0GridDescriptor_N_K(b0_g_l_k_lengths_vec, b0_g_l_k_strides_vec),
Number<BK1>{});
}
// B1: (N, L) view with the reduction dim split by L1 (B1K1).
__host__ __device__ static auto
MakeB1GridDescriptor(const std::array<index_t, 3>& b1_g_n_l_lengths_vec,
const std::array<index_t, 3>& b1_g_n_l_strides_vec)
{
return Transform::MakeB1GridDescriptor_BK0_N_BK1(
Transform::MakeB1GridDescriptor_N_K(b1_g_n_l_lengths_vec, b1_g_n_l_strides_vec),
Number<L1>{});
}
// Single D0 tensor: plain (M, N) descriptor (same shape as the Gemm0 output).
__host__ __device__ static auto
MakeD0GridDescriptor(const std::array<index_t, 3>& d0_g_m_n_lengths_vec,
const std::array<index_t, 3>& d0_g_m_n_strides_vec)
{
return Transform::MakeCGridDescriptor_M_N(d0_g_m_n_lengths_vec, d0_g_m_n_strides_vec);
}
// Tuple of descriptors for all NumD0Tensor D0 tensors.
__host__ __device__ static auto MakeD0sGridDescriptor(
const std::array<std::array<index_t, 3>, NumD0Tensor>& d0_g_m_n_lengths_vec,
const std::array<std::array<index_t, 3>, NumD0Tensor>& d0_g_m_n_strides_vec)
{
return generate_tuple(
[&](auto i) {
return MakeD0GridDescriptor(d0_g_m_n_lengths_vec[i], d0_g_m_n_strides_vec[i]);
},
Number<NumD0Tensor>{});
}
// Tuple of descriptors for all NumD1Tensor D1 tensors (E1/output-shaped).
__host__ __device__ static auto MakeD1sGridDescriptor(
const std::array<std::array<index_t, 3>, NumD1Tensor>& d1_g_m_o_lengths_vec,
const std::array<std::array<index_t, 3>, NumD1Tensor>& d1_g_m_o_strides_vec)
{
return generate_tuple(
[&](auto i) {
return MakeE1GridDescriptor(d1_g_m_o_lengths_vec[i], d1_g_m_o_strides_vec[i]);
},
Number<NumD1Tensor>{});
}
// E1 (final output): plain (M, N) descriptor.
__host__ __device__ static auto
MakeE1GridDescriptor(const std::array<index_t, 3>& e1_g_m_n_lengths_vec,
const std::array<index_t, 3>& e1_g_m_n_strides_vec)
{
return Transform::MakeCGridDescriptor_M_N(e1_g_m_n_lengths_vec, e1_g_m_n_strides_vec);
}
};
// Descriptor types deduced from the factories above (empty inputs suffice
// because only the types are needed here).
using AGridDesc = decltype(GridDescriptorCreator::MakeAGridDescriptor({}, {}));
using B0GridDesc = decltype(GridDescriptorCreator::MakeB0GridDescriptor({}, {}));
using D0sGridDesc =
remove_cvref_t<decltype(GridDescriptorCreator::MakeD0sGridDescriptor({}, {}))>;
using B1GridDesc = decltype(GridDescriptorCreator::MakeB1GridDescriptor({}, {}));
using D1sGridDesc =
remove_cvref_t<decltype(GridDescriptorCreator::MakeD1sGridDescriptor({}, {}))>;
using E1GridDesc = decltype(GridDescriptorCreator::MakeE1GridDescriptor({}, {}));
using CGridDesc_M_N =
decltype(GridDescriptorCreator::Transform::MakeCGridDescriptor_M_N({}, {}));
// Maps a batch index to per-tensor element offsets from the batch strides.
struct ComputeBasePtrOfStridedBatch
{
// Non-multi-D constructor: D strides stay default-initialized.
// NOTE(review): with NumD0/1Tensor > 0 this would leave the D stride
// arrays indeterminate — presumably only used when both counts are 0;
// confirm, or value-initialize them here.
ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
index_t BatchStrideB0,
index_t BatchStrideB1,
index_t BatchStrideC)
: BatchStrideA_(BatchStrideA),
BatchStrideB0_(BatchStrideB0),
BatchStrideB1_(BatchStrideB1),
BatchStrideC_E1_(BatchStrideC)
{
}
// Multi-D constructor: also records per-D0/D1 batch strides.
ComputeBasePtrOfStridedBatch(index_t BatchStrideA0,
index_t BatchStrideB0,
std::array<index_t, NumD0Tensor> BatchStrideD0s,
index_t BatchStrideB1,
std::array<index_t, NumD1Tensor> BatchStrideD1s,
index_t BatchStrideE1)
: BatchStrideA_(BatchStrideA0),
BatchStrideB0_(BatchStrideB0),
BatchStrideD0s_(BatchStrideD0s),
BatchStrideB1_(BatchStrideB1),
BatchStrideD1s_(BatchStrideD1s),
BatchStrideC_E1_(BatchStrideE1)
{
}
// All offsets are widened to long_index_t to avoid 32-bit overflow.
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideA_);
}
__host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideB0_);
}
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideB1_);
}
__host__ __device__ constexpr long_index_t GetCE1BasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideC_E1_);
}
// Offset of batch g_idx within the d0_idx-th D0 tensor.
template <index_t I>
__host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx,
Number<I> d0_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideD0s_[d0_idx]);
}
// Offset of batch g_idx within the d1_idx-th D1 tensor.
template <index_t I>
__host__ __device__ constexpr long_index_t GetD1BasePtr(index_t g_idx,
Number<I> d1_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideD1s_[d1_idx]);
}
private:
index_t BatchStrideA_;
index_t BatchStrideB0_;
std::array<index_t, NumD0Tensor> BatchStrideD0s_;
index_t BatchStrideB1_;
std::array<index_t, NumD1Tensor> BatchStrideD1s_;
index_t BatchStrideC_E1_;
};
};
template <typename DeviceOp,
GemmSpecialization GemmSpec,
typename ALayout,
typename B0layout,
typename D0sLayout,
typename B1Layout,
typename D1sLayout,
typename CE1Layout,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t LPerBlock, // Gemm0NPerBlock
ck::index_t KPerBlock, // Gemm0KPerBlock
ck::index_t NPerBlock, // Gemm1NPerBlock
typename ADataType,
typename B0DataType,
typename B1DataType,
typename AccDataType,
typename CE1DataType,
typename D0sDataType,
typename D1sDataType,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CDE1ElementwiseOperation,
ck::index_t AK1,
ck::index_t BK1,
ck::index_t L1, // B1K1
ck::index_t MPerWmma, // Gemm0/1 MPerWmma
ck::index_t LPerWmma, // Gemm0/1 NPerWmma
BlockGemmPipelineVersion BlkGemmPipelineVer,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t B0BlockTransferSrcVectorDim,
ck::index_t B0BlockTransferSrcScalarPerVector,
ck::index_t B1BlockTransferSrcVectorDim,
ck::index_t B1BlockTransferSrcScalarPerVector,
ck::index_t CDE0BlockTransferSrcScalarPerVector,
ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
bool IsMultiD = false>
struct DeviceGemmGemm_Wmma_CShuffleV3_Common_Invoker_Arg
{
using GridwiseGemm = typename DeviceOp::GridwiseOp;
using Common =
DeviceGemmGemm_Wmma_CShuffleV3_Common<DeviceOp,
GemmSpec,
ALayout,
B0layout,
D0sLayout,
B1Layout,
D1sLayout,
CE1Layout,
BlockSize,
MPerBlock,
LPerBlock,
KPerBlock,
NPerBlock,
ADataType,
B0DataType,
B1DataType,
AccDataType,
CE1DataType,
D0sDataType,
D1sDataType,
AElementwiseOperation,
B0ElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CDE1ElementwiseOperation,
AK1,
BK1,
L1,
MPerWmma,
LPerWmma,
BlkGemmPipelineVer,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
B0BlockTransferSrcVectorDim,
B0BlockTransferSrcScalarPerVector,
B1BlockTransferSrcVectorDim,
B1BlockTransferSrcScalarPerVector,
CDE0BlockTransferSrcScalarPerVector,
CShuffleBlockTransferScalarPerVector_NPerBlock,
IsMultiD>;
static constexpr auto NumD0Tensor = Common::NumD0Tensor;
static constexpr auto NumD1Tensor = Common::NumD1Tensor;
// Host-side argument bundle for the batched GEMM + GEMM kernel.
//
// Captures raw device pointers, the raw problem sizes (M, N, K, O, batch),
// and the element-wise operators, then eagerly builds:
//  - grid descriptors for A, B0, B1 and the C/E1 output,
//  - (when IsMultiD) descriptors and typed pointers for the D0/D1 tensors,
//  - the block-to-E-tile map and the per-batch base-pointer calculator
//    consumed by the kernel.
//
// Length/stride triples are ordered [batch, row, col]; the lowest-dimension
// stride is hardcoded to 1 for A, B0 and E1 (row-major assumption, enforced
// separately by IsSupportedArgument).
struct Argument : public BaseArgument
{
using arr3 = std::array<ck::index_t, 3>;
Argument(const ADataType* p_a_grid_,
const B0DataType* p_b0_grid_,
std::array<const void*, NumD0Tensor> p_d0s_grid_,
const B1DataType* p_b1_grid_,
std::array<const void*, NumD1Tensor> p_d1s_grid_,
CE1DataType* p_e1_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t O_,
index_t Batch,
index_t StrideA,
index_t StrideB0,
std::array<index_t, NumD0Tensor> StrideD0s,
index_t StrideB1,
std::array<index_t, NumD1Tensor> StrideD1s,
index_t StrideE1,
index_t BatchStrideA,
index_t BatchStrideB0,
std::array<index_t, NumD0Tensor> BatchStrideD0s,
index_t BatchStrideB1,
std::array<index_t, NumD1Tensor> BatchStrideD1s,
index_t BatchStrideE1,
AElementwiseOperation a_element_op_,
B0ElementwiseOperation b0_element_op_,
AccElementwiseOperation acc_element_op_,
B1ElementwiseOperation b1_element_op_,
CDE1ElementwiseOperation cde1_element_op_)
: p_a_grid{p_a_grid_},
p_b0_grid{p_b0_grid_},
p_d0s_grid{},
p_b1_grid{p_b1_grid_},
p_d1s_grid{},
p_c_e1_grid{p_e1_grid_},
M{M_},
N{N_},
K{K_},
O{O_},
batch_count{Batch},
a_element_op{a_element_op_},
b0_element_op{b0_element_op_},
acc_element_op{acc_element_op_},
b1_element_op{b1_element_op_},
cde1_element_op{cde1_element_op_},
compute_base_ptr_of_batch{BatchStrideA,
BatchStrideB0,
BatchStrideD0s,
BatchStrideB1,
BatchStrideD1s,
BatchStrideE1}
{
a_g_m_k_lengths = arr3{batch_count, M, K};
a_g_m_k_strides = arr3{BatchStrideA, StrideA, 1}; // A layout [batch_count, M, K]
b0_g_n_k_lengths = arr3{batch_count, N, K};
b0_g_n_k_strides = arr3{BatchStrideB0, StrideB0, 1}; // B0 layout [batch_count, N, K]
b1_g_o_n_lengths = arr3{batch_count, O, N};
// B1 strides depend on layout: row-major stores [batch, N, O] so the O
// dimension is contiguous; column-major stores [batch, O, N].
b1_g_o_n_strides =
is_same_v<B1Layout, tensor_layout::gemm::RowMajor>
? arr3{BatchStrideB1, 1, StrideB1} // B1 layout [batch_count, N, O]
: arr3{BatchStrideB1, StrideB1, 1}; // B1 layout [batch_count, O, N]
e1_g_m_o_lengths = arr3{batch_count, M, O};
e1_g_m_o_strides = arr3{BatchStrideE1, StrideE1, 1}; // C layout [batch_count, M, O]
// Build the transformed grid descriptors from the raw length/stride
// triples. The MBlock/NPerBlock view and the tile map are derived from
// the plain (M, N) output descriptor.
a_grid_desc = Common::GridDescriptorCreator::MakeAGridDescriptor(a_g_m_k_lengths,
a_g_m_k_strides);
b0_grid_desc = Common::GridDescriptorCreator::MakeB0GridDescriptor(b0_g_n_k_lengths,
b0_g_n_k_strides);
b1_grid_desc = Common::GridDescriptorCreator::MakeB1GridDescriptor(b1_g_o_n_lengths,
b1_g_o_n_strides);
c_e1_grid_desc_m_n = Common::GridDescriptorCreator::MakeE1GridDescriptor(
e1_g_m_o_lengths, e1_g_m_o_strides);
c_e1_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_e1_grid_desc_m_n);
block_2_etile_map = GridwiseGemm::MakeDefaultBlock2ETileMap(c_e1_grid_desc_m_n, 1, 1);
// Only the multiple-D variant carries D0/D1 tensors; otherwise the
// d0s/d1s members stay default-constructed (empty Tuple<>).
if constexpr(IsMultiD)
{
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
// D0s layout [batch_count, M, N]
d0s_g_m_n_lengths[i] = arr3{batch_count, M, N};
d0s_g_m_n_strides[i] = arr3{BatchStrideD0s[i], StrideD0s[i], 1};
// D0 pointer
p_d0s_grid(i) = static_cast<const D0DataType*>(p_d0s_grid_[i]);
});
// D0 desc
d0s_grid_desc = Common::GridDescriptorCreator::MakeD0sGridDescriptor(
d0s_g_m_n_lengths, d0s_g_m_n_strides);
static_for<0, NumD1Tensor, 1>{}([&](auto i) {
using D1DataType = remove_cvref_t<tuple_element_t<i.value, D1sDataType>>;
// D1s layout [batch_count, M, O]
d1s_g_m_o_lengths[i] = arr3{batch_count, M, O};
d1s_g_m_o_strides[i] = arr3{BatchStrideD1s[i], StrideD1s[i], 1};
// D1 pointer
p_d1s_grid(i) = static_cast<const D1DataType*>(p_d1s_grid_[i]);
});
// D1 desc
d1s_grid_desc = Common::GridDescriptorCreator::MakeD1sGridDescriptor(
d1s_g_m_o_lengths, d1s_g_m_o_strides);
d1s_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
d1s_grid_desc);
}
}
// Pointers
const ADataType* p_a_grid;
const B0DataType* p_b0_grid;
typename GridwiseGemm::D0sGridPointer p_d0s_grid;
const B1DataType* p_b1_grid;
typename GridwiseGemm::D1sGridPointer p_d1s_grid;
CE1DataType* p_c_e1_grid;
// Raw Problem Size
index_t M;
index_t N;
index_t K;
index_t O;
index_t batch_count;
arr3 a_g_m_k_lengths;
arr3 a_g_m_k_strides;
arr3 b0_g_n_k_lengths;
arr3 b0_g_n_k_strides;
std::array<arr3, NumD0Tensor> d0s_g_m_n_lengths;
std::array<arr3, NumD0Tensor> d0s_g_m_n_strides;
arr3 b1_g_o_n_lengths;
arr3 b1_g_o_n_strides;
std::array<arr3, NumD1Tensor> d1s_g_m_o_lengths;
std::array<arr3, NumD1Tensor> d1s_g_m_o_strides;
arr3 e1_g_m_o_lengths;
arr3 e1_g_m_o_strides;
AElementwiseOperation a_element_op;
B0ElementwiseOperation b0_element_op;
AccElementwiseOperation acc_element_op;
B1ElementwiseOperation b1_element_op;
CDE1ElementwiseOperation cde1_element_op;
// Grid descriptors and other mem calculators.
// The D0/D1-related members collapse to empty Tuple<> when !IsMultiD.
typename Common::AGridDesc a_grid_desc;
typename Common::B0GridDesc b0_grid_desc;
std::conditional_t<IsMultiD, typename Common::D0sGridDesc, Tuple<>> d0s_grid_desc;
typename Common::B1GridDesc b1_grid_desc;
typename Common::D1sGridDesc d1s_grid_desc;
std::conditional_t<
IsMultiD,
typename GridwiseGemm::D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
Tuple<>>
d1s_grid_desc_mblock_mperblock_nblock_nperblock;
std::conditional_t<IsMultiD, typename Common::E1GridDesc, typename Common::CGridDesc_M_N>
c_e1_grid_desc_m_n;
typename GridwiseGemm::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_e1_grid_desc_mblock_mperblock_nblock_nperblock;
typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map;
typename Common::ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch;
};
/// @brief Helper structure responsible for kernel invocation.
///
/// @paragraph The `Invoker` class is responsible for preparation and invocation of actual GPU
/// kernel function. It usually determines the launched grid size prepares kernel
/// arguments as well as perform specific kernel configuration selection based on
/// runtime arguments.
///
struct Invoker : public BaseInvoker
{
// Launches the batched GEMM+GEMM kernel.
// Grid size = batch_count * ceil(M / MPerBlock) * ceil(O / NPerBlock)
// 1D blocks of BlockSize threads; dynamic shared memory is 0 (the kernel
// declares its shared memory statically).
// Returns the measured kernel time (per launch_and_time_kernel semantics),
// or 0.0f when the HasMainKBlockLoop/TailNumber combination is invalid for
// the compiled pipeline version.
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto M0 = math::integer_divide_ceil(arg.M, MPerBlock);
const auto N0 = math::integer_divide_ceil(arg.O, NPerBlock);
const index_t grid_size = arg.batch_count * M0 * N0;
// Instantiates the kernel template for a compile-time
// (has_main_loop, tail_number) pair and launches it.
auto launch_kernel = [&](auto has_main_k_block_loop, auto tail_number) {
constexpr bool has_loop = decltype(has_main_k_block_loop)::value;
constexpr TailNumber tail_num = decltype(tail_number)::value;
const auto kernel = kernel_batched_gemm_gemm_wmma_cshuffle_v3<DeviceOp,
GridwiseGemm,
has_loop,
tail_num,
IsMultiD>;
return launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, arg);
};
// Runtime K decides which compile-time specialization is dispatched;
// only the combinations supported by the pipeline version are handled.
bool HasMainKBlockLoop = GridwiseGemm::CalculateHasMainKBlockLoop(arg.K);
TailNumber TailNum = GridwiseGemm::CalculateKBlockLoopTailNum(arg.K);
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
// v1 pipeline only supports the Full tail.
if(HasMainKBlockLoop && TailNum == TailNumber::Full)
{
return launch_kernel(std::integral_constant<bool, true>{},
std::integral_constant<TailNumber, TailNumber::Full>{});
}
else if(!HasMainKBlockLoop && TailNum == TailNumber::Full)
{
return launch_kernel(std::integral_constant<bool, false>{},
std::integral_constant<TailNumber, TailNumber::Full>{});
}
else
{
printf("Invalid HasMainKBlockLoop and TailNum combination for V1!\n");
return 0.0f;
}
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
// v3 pipeline: Full tail with a main loop, Even/Odd tails without.
if(HasMainKBlockLoop && TailNum == TailNumber::Full)
{
return launch_kernel(std::integral_constant<bool, true>{},
std::integral_constant<TailNumber, TailNumber::Full>{});
}
else if(!HasMainKBlockLoop && TailNum == TailNumber::Even)
{
return launch_kernel(std::integral_constant<bool, false>{},
std::integral_constant<TailNumber, TailNumber::Even>{});
}
else if(!HasMainKBlockLoop && TailNum == TailNumber::Odd)
{
return launch_kernel(std::integral_constant<bool, false>{},
std::integral_constant<TailNumber, TailNumber::Odd>{});
}
else
{
printf("Invalid HasMainKBlockLoop and TailNum combination for V3!\n");
return 0.0f;
}
}
else
{
printf("Invalid pipeline version!\n");
return 0.0f;
}
}
// polymorphic
// NOTE(review): dynamic_cast on a mismatched argument type yields nullptr
// and the dereference is UB — callers must pass this op's own Argument.
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
// Compile-time sanity check of the template configuration.
// Currently a stub that always accepts, so misconfigured tile parameters are
// only caught (if at all) by the runtime IsSupportedArgument checks.
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
// check if DsLayout is supported
template <typename RefLayout, typename DsLayout, const index_t NumDTensor>
static constexpr bool CheckDLayout()
{
bool valid = true;
// iterate over DLayout tuple
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
// if RefLayout and DLayout are same, keep valid true, otherwise false
valid = valid && is_same_v<RefLayout, DLayout>;
});
return valid;
}
// Runtime validation of a prepared Argument against this instance's
// compile-time configuration: target architecture, data types, layouts,
// padding mode, per-WMMA tile sizes, gridwise-op validity, and the
// vector-access (scalar-per-vector and unit-stride) requirements.
// Returns false (optionally logging the reason when CK_LOGGING is set) if
// the argument cannot be run by this kernel.
static bool IsSupportedArgument(const Argument& arg)
{
    // Print lambda with env check and printf() style formatting.
    const char* curFunc = __func__;
    auto print = [&curFunc](const char* format, ...) -> void {
        if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
        {
#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wformat-nonliteral"
#endif
            va_list args;
            va_start(args, format);
            std::vfprintf(stdout, format, args);
            va_end(args);
#if defined(__clang__)
#pragma clang diagnostic pop
#endif
            std::cout << "In file: " << __FILE__ << ", function: " << curFunc << "\n";
        }
    };
    // WMMA path only exists on gfx11/gfx12.
    if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported()))
    {
        print("DeviceOp: Arch err\n");
        return false;
    }
    // fp8/bf8 inputs are not available on gfx11.
    if constexpr(std::is_same_v<ADataType, f8_t> || std::is_same_v<ADataType, bf8_t> ||
                 std::is_same_v<B0DataType, f8_t> || std::is_same_v<B0DataType, bf8_t> ||
                 std::is_same_v<B1DataType, f8_t> || std::is_same_v<B1DataType, bf8_t>)
    {
        if(ck::is_gfx11_supported())
        {
            print("DeviceOp: gfx 11 does not support fp8\n");
            return false;
        }
    }
    if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
    {
        print("DeviceOp: Acc0 Type err\n");
        return false;
    }
    if constexpr(!(is_same_v<ALayout, tensor_layout::gemm::RowMajor>))
    {
        print("DeviceOp: A layout must be Row\n");
        return false;
    }
    if constexpr(!(is_same_v<B1Layout, tensor_layout::gemm::RowMajor> ||
                   is_same_v<B1Layout, tensor_layout::gemm::ColumnMajor>))
    {
        print("DeviceOp: B1 layout must be Column or Row\n");
        return false;
    }
    if constexpr(!(is_same_v<CE1Layout, tensor_layout::gemm::RowMajor>))
    {
        print("DeviceOp: C layout must be Row\n");
        return false;
    }
    // Other padding modes have not been tested and do not get checked individually.
    if constexpr(GemmSpec != GemmSpecialization::Default &&
                 GemmSpec != GemmSpecialization::MNKOPadding)
    {
        print("Padding mode must be default or MNKO\n");
        return false;
    }
    // Per wmma dimensions not equal to 16 are very untested.
    if constexpr(MPerWmma != 16 || LPerWmma != 16 || DeviceOp::NPerWmma != 16)
    {
        print("M, L, N per Wmma must be 16\n");
        return false;
    }
    if constexpr(IsMultiD)
    {
        if constexpr(!(is_same_v<B0layout, tensor_layout::gemm::ColumnMajor>))
        {
            print("DeviceOp: B0 layout must be Column\n");
            return false;
        }
        if constexpr(!(CheckDLayout<tensor_layout::gemm::RowMajor, D0sLayout, NumD0Tensor>()))
        {
            print("DeviceOp: All D0s layout must be Row\n");
            return false;
        }
        if constexpr(!(CheckDLayout<tensor_layout::gemm::RowMajor, D1sLayout, NumD1Tensor>()))
        {
            print("DeviceOp: All D1s layout must be Row\n");
            return false;
        }
        if(!GridwiseGemm::CheckValidity(arg.a_grid_desc,
                                        arg.b0_grid_desc,
                                        arg.d0s_grid_desc,
                                        arg.b1_grid_desc,
                                        arg.d1s_grid_desc,
                                        arg.c_e1_grid_desc_m_n,
                                        arg.block_2_etile_map))
        {
            return false;
        }
        // Check scalar per vector requirement: the extent of the dimension a
        // tensor is vectorized over must be divisible by its vector width.
        const auto a_extent_lowest    = ABlockTransferSrcVectorDim == 2 ? arg.K : arg.M;
        const auto b0_extent_lowest   = B0BlockTransferSrcVectorDim == 2 ? arg.K : arg.N;
        const auto cde0_extent_lowest = arg.N; // D0 tensors forced to be row-major
        const auto b1_extent_lowest   = B1BlockTransferSrcVectorDim == 2 ? arg.N : arg.O;
        const auto cde1_extent_lowest = arg.O;
        if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
             b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 &&
             cde0_extent_lowest % CDE0BlockTransferSrcScalarPerVector == 0 &&
             b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
             cde1_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
        {
            print("DeviceOp: Data Transfer Vector scalar err\n");
            return false;
        }
        // Check vector load/store requirement: every tensor must have unit
        // stride along the dimension it is vectorized over.
        const auto a_stride_lowest =
            ABlockTransferSrcVectorDim == 2 ? arg.a_g_m_k_strides[2] : arg.a_g_m_k_strides[1];
        const auto b0_stride_lowest = B0BlockTransferSrcVectorDim == 2
                                          ? arg.b0_g_n_k_strides[2]
                                          : arg.b0_g_n_k_strides[1];
        const auto b1_stride_lowest = B1BlockTransferSrcVectorDim == 2
                                          ? arg.b1_g_o_n_strides[2]
                                          : arg.b1_g_o_n_strides[1];
        const auto e1_stride_lowest = arg.e1_g_m_o_strides[2];
        // NOTE: We don't check D0s/D1s stride, as they are already forced to be row-major
        // and the lowest dimension stride is hardcoded to 1
        // Fix: this previously used '||' (pass when ANY stride is 1), which
        // was vacuously true because the A/B0/E1 lowest strides are hardcoded
        // to 1 in Argument — so B1's stride was never actually validated.
        // All strides must be 1 for the vectorized accesses to be contiguous.
        if(!(a_stride_lowest == 1 && b0_stride_lowest == 1 && b1_stride_lowest == 1 &&
             e1_stride_lowest == 1))
        {
            print("DeviceOp: Data Vectorize transfer err\n");
            return false;
        }
    }
    else
    {
        if(!GridwiseGemm::CheckValidity(arg.a_grid_desc,
                                        arg.b0_grid_desc,
                                        Tuple<>{},
                                        arg.b1_grid_desc,
                                        Tuple<>{},
                                        arg.c_e1_grid_desc_m_n,
                                        arg.block_2_etile_map))
        {
            return false;
        }
        // Check scalar per vector requirement (no D tensors in this variant).
        const auto a_extent_lowest  = ABlockTransferSrcVectorDim == 2 ? arg.K : arg.M;
        const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? arg.K : arg.N;
        const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? arg.N : arg.O;
        const auto c_extent_lowest  = arg.O;
        if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
             b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 &&
             b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
             c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
        {
            print("DeviceOp: Data Transfer Vector scalar err\n");
            return false;
        }
        // Check vector load/store requirement (see the note in the IsMultiD
        // branch; '||' here was likewise a vacuous check and is fixed to '&&').
        const auto a_stride_lowest =
            ABlockTransferSrcVectorDim == 2 ? arg.a_g_m_k_strides[2] : arg.a_g_m_k_strides[1];
        const auto b0_stride_lowest = B0BlockTransferSrcVectorDim == 2
                                          ? arg.b0_g_n_k_strides[2]
                                          : arg.b0_g_n_k_strides[1];
        const auto b1_stride_lowest = B1BlockTransferSrcVectorDim == 2
                                          ? arg.b1_g_o_n_strides[2]
                                          : arg.b1_g_o_n_strides[1];
        const auto c_stride_lowest = arg.e1_g_m_o_strides[2];
        if(!(a_stride_lowest == 1 && b0_stride_lowest == 1 && b1_stride_lowest == 1 &&
             c_stride_lowest == 1))
        {
            print("DeviceOp: Data Vectorize transfer err\n");
            return false;
        }
    }
    // K must split evenly into AK1/BK1 chunks unless padding is enabled.
    if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MNKOPadding))
    {
        return false;
    }
    return true;
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -3,91 +3,20 @@
#pragma once
#include <iostream>
#include <sstream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3_common.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// Device entry point for batched GEMM + multiple-D GEMM with WMMA + CShuffle.
// Grid layout: grid_size = batch_count * (blocks per batch); each block maps
// its 1D block id to a batch index g_idx and a tile within that batch.
// Shared memory is declared statically, sized by the gridwise op.
// Compiled to a real body only for host passes or gfx11/gfx12 device passes.
template <typename DeviceOp, typename GridwiseOp, bool HasMainKBlockLoop, TailNumber TailNum>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3(typename DeviceOp::RawArg arg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
__shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
// These values are uniform across the wavefront; readfirstlane broadcasts
// them from lane 0 so they live in scalar registers.
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / arg.batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
// Per-batch element offsets for each operand of this block's batch.
const long_index_t a_batch_offset =
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetABasePtr(g_idx)))
const long_index_t b0_batch_offset =
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB0BasePtr(g_idx)));
const long_index_t b1_batch_offset =
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
const long_index_t e1_batch_offset =
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetE1BasePtr(g_idx)));
// Shift each D0/D1 pointer by its own batch stride for this batch.
auto p_d0s_grid = GridwiseOp::MakeD0sGridPointer();
auto p_d1s_grid = GridwiseOp::MakeD1sGridPointer();
static_for<0, DeviceOp::NumD0Tensor, 1>{}([&](auto In) {
const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg.compute_base_ptr_of_batch.GetD0BasePtr(g_idx, In)));
p_d0s_grid(In) = arg.p_d0s_grid(In) + d0_batch_offset;
});
static_for<0, DeviceOp::NumD1Tensor, 1>{}([&](auto In) {
const long_index_t d1_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg.compute_base_ptr_of_batch.GetD1BasePtr(g_idx, In)));
p_d1s_grid(In) = arg.p_d1s_grid(In) + d1_batch_offset;
});
GridwiseOp::template Run<HasMainKBlockLoop, TailNum>(
arg.p_a_grid + a_batch_offset,
arg.p_b0_grid + b0_batch_offset,
p_d0s_grid,
arg.p_b1_grid + b1_batch_offset,
p_d1s_grid,
arg.p_e1_grid + e1_batch_offset,
p_shared,
arg.a_grid_desc,
arg.b0_grid_desc,
arg.d0s_grid_desc,
arg.b1_grid_desc,
arg.d1s_grid_desc_mblock_mperblock_nblock_nperblock,
arg.e1_grid_desc_mblock_mperblock_nblock_nperblock,
arg.a_element_op,
arg.b0_element_op,
arg.acc_element_op,
arg.b1_element_op,
arg.cde1_element_op,
arg.block_2_etile_map);
#else
ignore = arg;
#endif // (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)
}
// Computes:
// Acc = Acc_Op(A_Op(A) * B0_Op(B0), D0_0, D0_1, ...)
// E = CDE1_Op(Acc_Op(Acc0) * B1_Op(B1), D1_0, D1_1, ...)
@@ -184,151 +113,51 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3
static constexpr index_t NumD0Tensor = D0sDataType::Size();
static constexpr index_t NumD1Tensor = D1sDataType::Size();
static constexpr auto I0 = Number<0>{};
// To match XDL implementation NPerWmma (A.k.a Gemm1 NPerWmma) is set equal
// to LPerWmma (A.k.a Gemm0 NPerWmma).
static constexpr index_t NPerWmma = LPerWmma;
// TODO: Now that we are no longer using NumDim or TensorSpec, we can probably use a simpler
// Transform operator or just not use one at all.
using Transform = TransformBatchedContractionContractionToBatchedGemmGemm_Wmma<
Sequence<1, 1, 1, 1, 1>,
Sequence<MPerBlock, LPerBlock, KPerBlock, NPerBlock>,
GemmSpec,
TensorSpecialization::Default, // ASpec
TensorSpecialization::Default, // B0Spec
TensorSpecialization::Default, // B1Spec
TensorSpecialization::Default>; // CSpec
// Builds the A grid descriptor from [G, M, K] lengths/strides: first an
// (M, K) view, then K is split into (AK0, AK1) for AK1-wide vector loads.
__host__ __device__ static auto
MakeAGridDescriptor(const std::array<index_t, 3>& a_g_m_k_lengths_vec,
const std::array<index_t, 3>& a_g_m_k_strides_vec)
{
return Transform::MakeAGridDescriptor_AK0_M_AK1(
Transform::MakeAGridDescriptor_M_K(a_g_m_k_lengths_vec, a_g_m_k_strides_vec),
Number<AK1>{});
}
// Builds the B0 grid descriptor from [G, L, K] lengths/strides: an (N, K)
// view with K split into (BK0, BK1) for BK1-wide vector loads.
__host__ __device__ static auto
MakeB0GridDescriptor(const std::array<index_t, 3>& b0_g_l_k_lengths_vec,
const std::array<index_t, 3>& b0_g_l_k_strides_vec)
{
return Transform::MakeB0GridDescriptor_BK0_N_BK1(
Transform::MakeB0GridDescriptor_N_K(b0_g_l_k_lengths_vec, b0_g_l_k_strides_vec),
Number<BK1>{});
}
// Builds the B1 grid descriptor from [G, N, L] lengths/strides; the second
// GEMM's K dimension (L here) is split using L1 (a.k.a. B1K1).
__host__ __device__ static auto
MakeB1GridDescriptor(const std::array<index_t, 3>& b1_g_n_l_lengths_vec,
const std::array<index_t, 3>& b1_g_n_l_strides_vec)
{
return Transform::MakeB1GridDescriptor_BK0_N_BK1(
Transform::MakeB1GridDescriptor_N_K(b1_g_n_l_lengths_vec, b1_g_n_l_strides_vec),
Number<L1>{});
}
// Builds a single D0 tensor descriptor as a plain (M, N) view of the
// [G, M, N] lengths/strides.
__host__ __device__ static auto
MakeD0GridDescriptor(const std::array<index_t, 3>& d0_g_m_n_lengths_vec,
const std::array<index_t, 3>& d0_g_m_n_strides_vec)
{
return Transform::MakeCGridDescriptor_M_N(d0_g_m_n_lengths_vec, d0_g_m_n_strides_vec);
}
// Builds a tuple of descriptors, one per D0 tensor, from per-tensor
// [G, M, N] lengths/strides.
__host__ __device__ static auto MakeD0sGridDescriptor(
const std::array<std::array<index_t, 3>, NumD0Tensor>& d0_g_m_n_lengths_vec,
const std::array<std::array<index_t, 3>, NumD0Tensor>& d0_g_m_n_strides_vec)
{
return generate_tuple(
[&](auto i) {
return MakeD0GridDescriptor(d0_g_m_n_lengths_vec[i], d0_g_m_n_strides_vec[i]);
},
Number<NumD0Tensor>{});
}
// Builds a tuple of descriptors, one per D1 tensor, from per-tensor
// [G, M, O] lengths/strides. D1 tensors share the E1 output shape, so each
// descriptor is produced by MakeE1GridDescriptor.
// Fix: the parameter arrays were sized NumD0Tensor while the tuple iterates
// NumD1Tensor entries — broken whenever NumD0Tensor != NumD1Tensor.
__host__ __device__ static auto MakeD1sGridDescriptor(
    const std::array<std::array<index_t, 3>, NumD1Tensor>& d1_g_m_o_lengths_vec,
    const std::array<std::array<index_t, 3>, NumD1Tensor>& d1_g_m_o_strides_vec)
{
    return generate_tuple(
        [&](auto i) {
            return MakeE1GridDescriptor(d1_g_m_o_lengths_vec[i], d1_g_m_o_strides_vec[i]);
        },
        Number<NumD1Tensor>{});
}
// Builds the E1 output descriptor as a plain (M, N) view of the [G, M, N]
// lengths/strides.
__host__ __device__ static auto
MakeE1GridDescriptor(const std::array<index_t, 3>& e1_g_m_n_lengths_vec,
const std::array<index_t, 3>& e1_g_m_n_strides_vec)
{
return Transform::MakeCGridDescriptor_M_N(e1_g_m_n_lengths_vec, e1_g_m_n_strides_vec);
}
using AGridDesc = decltype(MakeAGridDescriptor({}, {}));
using B0GridDesc = decltype(MakeB0GridDescriptor({}, {}));
using D0sGridDesc = remove_cvref_t<decltype(MakeD0sGridDescriptor({}, {}))>;
using B1GridDesc = decltype(MakeB1GridDescriptor({}, {}));
using D1sGridDesc = remove_cvref_t<decltype(MakeD1sGridDescriptor({}, {}))>;
using E1GridDesc = decltype(MakeE1GridDescriptor({}, {}));
// Translates a batch index g_idx into per-tensor element offsets using the
// batch strides captured at construction. Every offset is widened to
// long_index_t before the multiply to avoid 32-bit overflow on large
// batched problems.
struct ComputeBasePtrOfStridedBatch
{
    ComputeBasePtrOfStridedBatch(index_t BatchStrideA0,
                                 index_t BatchStrideB0,
                                 std::array<index_t, NumD0Tensor> BatchStrideD0s,
                                 index_t BatchStrideB1,
                                 std::array<index_t, NumD1Tensor> BatchStrideD1s,
                                 index_t BatchStrideE1)
        : BatchStrideA0_(BatchStrideA0),
          BatchStrideB0_(BatchStrideB0),
          BatchStrideD0s_(BatchStrideD0s),
          BatchStrideB1_(BatchStrideB1),
          BatchStrideD1s_(BatchStrideD1s),
          BatchStrideE1_(BatchStrideE1)
    {
    }
    // Element offset of batch g_idx into A.
    __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
    {
        return g_idx * static_cast<long_index_t>(BatchStrideA0_);
    }
    // Element offset of batch g_idx into B0.
    __host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const
    {
        return g_idx * static_cast<long_index_t>(BatchStrideB0_);
    }
    // Element offset of batch g_idx into the I-th D0 tensor.
    // Fix: parameter renamed d1_idx -> d0_idx (copy/paste slip).
    template <index_t I>
    __host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx,
                                                            Number<I> d0_idx) const
    {
        return g_idx * static_cast<long_index_t>(BatchStrideD0s_[d0_idx]);
    }
    // Element offset of batch g_idx into B1.
    __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
    {
        return g_idx * static_cast<long_index_t>(BatchStrideB1_);
    }
    // Element offset of batch g_idx into the E1 output.
    __host__ __device__ constexpr long_index_t GetE1BasePtr(index_t g_idx) const
    {
        return g_idx * static_cast<long_index_t>(BatchStrideE1_);
    }
    // Element offset of batch g_idx into the I-th D1 tensor.
    // Fix: explicit long_index_t return (was auto) to match the other
    // accessors; the deduced type was already long_index_t, so behavior is
    // unchanged.
    template <index_t I>
    __host__ __device__ constexpr long_index_t GetD1BasePtr(index_t g_idx, Number<I> d1_idx) const
    {
        return g_idx * static_cast<long_index_t>(BatchStrideD1s_[d1_idx]);
    }

    private:
    index_t BatchStrideA0_;
    index_t BatchStrideB0_;
    std::array<index_t, NumD0Tensor> BatchStrideD0s_;
    index_t BatchStrideB1_;
    std::array<index_t, NumD1Tensor> BatchStrideD1s_;
    index_t BatchStrideE1_;
};
using DeviceGemmGemmCommonBase =
DeviceGemmGemm_Wmma_CShuffleV3_Common<DeviceOp,
GemmSpec,
ALayout,
B0layout,
D0sLayout,
B1Layout,
D1sLayout,
E1Layout,
BlockSize,
MPerBlock,
LPerBlock,
KPerBlock,
NPerBlock,
ADataType,
B0DataType,
B1DataType,
AccDataType,
E1DataType,
D0sDataType,
D1sDataType,
AElementwiseOperation,
B0ElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CDE1ElementwiseOperation,
AK1,
BK1,
L1,
MPerWmma,
LPerWmma,
BlkGemmPipelineVer,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
B0BlockTransferSrcVectorDim,
B0BlockTransferSrcScalarPerVector,
B1BlockTransferSrcVectorDim,
B1BlockTransferSrcScalarPerVector,
CDE0BlockTransferSrcScalarPerVector,
CShuffleBlockTransferScalarPerVector_NPerBlock,
true>; // IsMultiD
// GridwiseOp
using GridwiseOp = GridwiseBatchedGemmGemm_wmma_cshuffle_v3<
@@ -350,12 +179,12 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3
CDE1ElementwiseOperation,
InMemoryDataOperationEnum::Set,
// InMemory Data Descriptor
AGridDesc,
B0GridDesc,
D0sGridDesc,
B1GridDesc,
D1sGridDesc,
E1GridDesc,
typename DeviceGemmGemmCommonBase::AGridDesc,
typename DeviceGemmGemmCommonBase::B0GridDesc,
typename DeviceGemmGemmCommonBase::D0sGridDesc,
typename DeviceGemmGemmCommonBase::B1GridDesc,
typename DeviceGemmGemmCommonBase::D1sGridDesc,
typename DeviceGemmGemmCommonBase::E1GridDesc,
// Tiling Family
MPerBlock,
LPerBlock,
@@ -402,430 +231,67 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3
CShuffleNRepeatPerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
Transform::matrix_padder.PadN,
DeviceGemmGemmCommonBase::GridDescriptorCreator::Transform::matrix_padder.PadN,
BlkGemmPipeSched,
BlkGemmPipelineVer>;
struct RawArg : public BaseArgument
using DeviceGemmGemmCommon = DeviceGemmGemm_Wmma_CShuffleV3_Common_Invoker_Arg<
DeviceOp,
GemmSpec,
ALayout,
B0layout,
D0sLayout,
B1Layout,
D1sLayout,
E1Layout,
BlockSize,
MPerBlock,
LPerBlock,
KPerBlock,
NPerBlock,
ADataType,
B0DataType,
B1DataType,
AccDataType,
E1DataType,
D0sDataType,
D1sDataType,
AElementwiseOperation,
B0ElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CDE1ElementwiseOperation,
AK1,
BK1,
L1,
MPerWmma,
LPerWmma,
BlkGemmPipelineVer,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
B0BlockTransferSrcVectorDim,
B0BlockTransferSrcScalarPerVector,
B1BlockTransferSrcVectorDim,
B1BlockTransferSrcScalarPerVector,
CDE0BlockTransferSrcScalarPerVector,
CShuffleBlockTransferScalarPerVector_NPerBlock,
true>; // IsMultiD
// Invoker
using Invoker = typename DeviceGemmGemmCommon::Invoker;
// Argument
using Argument = typename DeviceGemmGemmCommon::Argument;
static bool IsSupportedArgument(const Argument& arg)
{
using arr3 = std::array<ck::index_t, 3>;
RawArg(const ADataType* p_a_grid_,
const B0DataType* p_b0_grid_,
std::array<const void*, NumD0Tensor> p_d0s_grid_,
const B1DataType* p_b1_grid_,
std::array<const void*, NumD1Tensor> p_d1s_grid_,
E1DataType* p_e1_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t O_,
index_t Batch,
index_t StrideA,
index_t StrideB0,
std::array<index_t, NumD0Tensor> StrideD0s,
index_t StrideB1,
std::array<index_t, NumD1Tensor> StrideD1s,
index_t StrideE1,
index_t BatchStrideA,
index_t BatchStrideB0,
std::array<index_t, NumD0Tensor> BatchStrideD0s,
index_t BatchStrideB1,
std::array<index_t, NumD1Tensor> BatchStrideD1s,
index_t BatchStrideE1,
AElementwiseOperation a_element_op_,
B0ElementwiseOperation b0_element_op_,
AccElementwiseOperation acc_element_op_,
B1ElementwiseOperation b1_element_op_,
CDE1ElementwiseOperation cde1_element_op_)
: p_a_grid{p_a_grid_},
p_b0_grid{p_b0_grid_},
p_d0s_grid{},
p_b1_grid{p_b1_grid_},
p_d1s_grid{},
p_e1_grid{p_e1_grid_},
M{M_},
N{N_},
K{K_},
O{O_},
batch_count{Batch},
a_element_op{a_element_op_},
b0_element_op{b0_element_op_},
acc_element_op{acc_element_op_},
b1_element_op{b1_element_op_},
cde1_element_op{cde1_element_op_},
compute_base_ptr_of_batch{BatchStrideA,
BatchStrideB0,
BatchStrideD0s,
BatchStrideB1,
BatchStrideD1s,
BatchStrideE1}
{
a_g_m_k_lengths = arr3{batch_count, M, K};
a_g_m_k_strides = arr3{BatchStrideA, StrideA, 1}; // A layout [batch_count, M, K]
b0_g_n_k_lengths = arr3{batch_count, N, K};
b0_g_n_k_strides = arr3{BatchStrideB0, StrideB0, 1}; // B0 layout [batch_count, N, K]
b1_g_o_n_lengths = arr3{batch_count, O, N};
b1_g_o_n_strides =
is_same_v<B1Layout, tensor_layout::gemm::RowMajor>
? arr3{BatchStrideB1, 1, StrideB1} // B1 layout [batch_count, N, O]
: arr3{BatchStrideB1, StrideB1, 1}; // B1 layout [batch_count, O, N]
e1_g_m_o_lengths = arr3{batch_count, M, O};
e1_g_m_o_strides = arr3{BatchStrideE1, StrideE1, 1}; // C layout [batch_count, M, O]
a_grid_desc = MakeAGridDescriptor(a_g_m_k_lengths, a_g_m_k_strides);
b0_grid_desc = MakeB0GridDescriptor(b0_g_n_k_lengths, b0_g_n_k_strides);
b1_grid_desc = MakeB1GridDescriptor(b1_g_o_n_lengths, b1_g_o_n_strides);
e1_grid_desc_m_n = MakeE1GridDescriptor(e1_g_m_o_lengths, e1_g_m_o_strides);
e1_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseOp::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e1_grid_desc_m_n);
block_2_etile_map = GridwiseOp::MakeDefaultBlock2ETileMap(e1_grid_desc_m_n, 1, 1);
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
// D0s layout [batch_count, M, N]
d0s_g_m_n_lengths[i] = arr3{batch_count, M, N};
d0s_g_m_n_strides[i] = arr3{BatchStrideD0s[i], StrideD0s[i], 1};
// D0 pointer
p_d0s_grid(i) = static_cast<const D0DataType*>(p_d0s_grid_[i]);
// D0 desc
d0s_grid_desc(i) = MakeD0GridDescriptor(d0s_g_m_n_lengths[i], d0s_g_m_n_strides[i]);
});
static_for<0, NumD1Tensor, 1>{}([&](auto i) {
using D1DataType = remove_cvref_t<tuple_element_t<i.value, D1sDataType>>;
// D1s layout [batch_count, M, O]
d1s_g_m_o_lengths[i] = arr3{batch_count, M, O};
d1s_g_m_o_strides[i] = arr3{BatchStrideD1s[i], StrideD1s[i], 1};
// D1 pointer
p_d1s_grid(i) = static_cast<const D1DataType*>(p_d1s_grid_[i]);
// D1 desc
d1s_grid_desc(i) = MakeE1GridDescriptor(d1s_g_m_o_lengths[i], d1s_g_m_o_strides[i]);
});
d1s_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseOp::MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(d1s_grid_desc);
}
// Pointers
const ADataType* p_a_grid;
const B0DataType* p_b0_grid;
typename GridwiseOp::D0sGridPointer p_d0s_grid;
const B1DataType* p_b1_grid;
typename GridwiseOp::D1sGridPointer p_d1s_grid;
E1DataType* p_e1_grid;
// Raw Problem Size
index_t M;
index_t N;
index_t K;
index_t O;
index_t batch_count;
arr3 a_g_m_k_lengths;
arr3 a_g_m_k_strides;
arr3 b0_g_n_k_lengths;
arr3 b0_g_n_k_strides;
std::array<arr3, NumD0Tensor> d0s_g_m_n_lengths;
std::array<arr3, NumD0Tensor> d0s_g_m_n_strides;
arr3 b1_g_o_n_lengths;
arr3 b1_g_o_n_strides;
std::array<arr3, NumD1Tensor> d1s_g_m_o_lengths;
std::array<arr3, NumD1Tensor> d1s_g_m_o_strides;
arr3 e1_g_m_o_lengths;
arr3 e1_g_m_o_strides;
AElementwiseOperation a_element_op;
B0ElementwiseOperation b0_element_op;
AccElementwiseOperation acc_element_op;
B1ElementwiseOperation b1_element_op;
CDE1ElementwiseOperation cde1_element_op;
// Grid descriptors and other mem calculators
AGridDesc a_grid_desc;
B0GridDesc b0_grid_desc;
D0sGridDesc d0s_grid_desc;
B1GridDesc b1_grid_desc;
D1sGridDesc d1s_grid_desc;
typename GridwiseOp::D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
d1s_grid_desc_mblock_mperblock_nblock_nperblock;
E1GridDesc e1_grid_desc_m_n;
typename GridwiseOp::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e1_grid_desc_mblock_mperblock_nblock_nperblock;
typename GridwiseOp::DefaultBlock2ETileMap block_2_etile_map;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch;
};
// check if DsLayout is supported
template <typename RefLayout, typename DsLayout, const index_t NumDTensor>
static constexpr bool CheckDLayout()
{
bool valid = true;
// iterate over DLayout tuple
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
// if RefLayout and DLayout are same, keep valid true, otherwise false
valid = valid && is_same_v<RefLayout, DLayout>;
});
return valid;
return DeviceGemmGemmCommon::IsSupportedArgument(arg);
}
// Static feasibility check for the fused batched GEMM + GEMM problem
// described by `arg`. Verifies target architecture, data types, tensor
// layouts, padding specialization, per-WMMA tile sizes, gridwise-descriptor
// validity, and the divisibility/contiguity requirements of vectorized
// global-memory transfers. Returns true iff this instance can run `arg`.
static bool IsSupportedArgument([[maybe_unused]] const RawArg& arg)
{
    // Print lambda with env check and printf() style formatting.
    // Output is emitted only when the CK_LOGGING environment switch is on.
    const char* curFunc = __func__;
    auto print = [&curFunc](const char* format, ...) -> void {
        if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
        {
#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wformat-nonliteral"
#endif
            va_list args;
            va_start(args, format);
            std::vfprintf(stdout, format, args);
            va_end(args);
#if defined(__clang__)
#pragma clang diagnostic pop
#endif
            std::cout << "In file: " << __FILE__ << ", function: " << curFunc << "\n";
        }
    };
    // Only gfx11 and gfx12 targets are supported by this WMMA path.
    if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported()))
    {
        print("DeviceOp: Arch err\n");
        return false;
    }
    // fp8/bf8 inputs are rejected on gfx11.
    if constexpr(std::is_same_v<ADataType, f8_t> || std::is_same_v<ADataType, bf8_t> ||
                 std::is_same_v<B0DataType, f8_t> || std::is_same_v<B0DataType, bf8_t> ||
                 std::is_same_v<B1DataType, f8_t> || std::is_same_v<B1DataType, bf8_t>)
    {
        if(ck::is_gfx11_supported())
        {
            print("DeviceOp: gfx 11 does not support fp8\n");
            return false;
        }
    }
    // Accumulator type must be float or int32.
    if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
    {
        print("DeviceOp: Acc0 Type err\n");
        return false;
    }
    // Layout restrictions: A row-major, B0 column-major, all D0s/D1s and E1
    // row-major; B1 may be either row- or column-major.
    if constexpr(!(is_same_v<ALayout, tensor_layout::gemm::RowMajor>))
    {
        print("DeviceOp: A layout must be Row\n");
        return false;
    }
    if constexpr(!(is_same_v<B0layout, tensor_layout::gemm::ColumnMajor>))
    {
        print("DeviceOp: B0 layout must be Column\n");
        return false;
    }
    if constexpr(!(CheckDLayout<tensor_layout::gemm::RowMajor, D0sLayout, NumD0Tensor>()))
    {
        print("DeviceOp: All D0s layout must be Row\n");
        return false;
    }
    if constexpr(!(is_same_v<B1Layout, tensor_layout::gemm::RowMajor> ||
                   is_same_v<B1Layout, tensor_layout::gemm::ColumnMajor>))
    {
        print("DeviceOp: B1 layout must be Column or Row\n");
        return false;
    }
    if constexpr(!(CheckDLayout<tensor_layout::gemm::RowMajor, D1sLayout, NumD1Tensor>()))
    {
        print("DeviceOp: All D1s layout must be Row\n");
        return false;
    }
    if constexpr(!(is_same_v<E1Layout, tensor_layout::gemm::RowMajor>))
    {
        print("DeviceOp: C layout must be Row\n");
        return false;
    }
    // Other padding modes have not been tested and do not get checked individually.
    if constexpr(GemmSpec != GemmSpecialization::Default &&
                 GemmSpec != GemmSpecialization::MNKOPadding)
    {
        print("Padding mode must be default or MNKO\n");
        return false;
    }
    // Per wmma dimensions not equal to 16 are very untested.
    if constexpr(MPerWmma != 16 || LPerWmma != 16 || NPerWmma != 16)
    {
        print("M, L, N per Wmma must be 16\n");
        return false;
    }
    // Delegate descriptor / tile-map consistency checks to the gridwise op.
    if(!GridwiseOp::CheckValidity(arg.a_grid_desc,
                                  arg.b0_grid_desc,
                                  arg.d0s_grid_desc,
                                  arg.b1_grid_desc,
                                  arg.d1s_grid_desc,
                                  arg.e1_grid_desc_m_n,
                                  arg.block_2_etile_map))
    {
        return false;
    }
    // Check scalar per vector requirement: the extent of each tensor's
    // vectorized dimension must be divisible by its ScalarPerVector.
    const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? arg.K : arg.M;
    const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? arg.K : arg.N;
    const auto cde0_extent_lowest = arg.N; // D0 tensors forced to be row-major
    const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? arg.N : arg.O;
    const auto cde1_extent_lowest = arg.O;
    if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
         b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 &&
         cde0_extent_lowest % CDE0BlockTransferSrcScalarPerVector == 0 &&
         b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
         cde1_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
    {
        print("DeviceOp: Data Transfer Vector scalar err\n");
        return false;
    }
    // Check vector load/store requirement
    const auto a_stride_lowest =
        ABlockTransferSrcVectorDim == 2 ? arg.a_g_m_k_strides[2] : arg.a_g_m_k_strides[1];
    const auto b0_stride_lowest =
        B0BlockTransferSrcVectorDim == 2 ? arg.b0_g_n_k_strides[2] : arg.b0_g_n_k_strides[1];
    const auto b1_stride_lowest =
        B1BlockTransferSrcVectorDim == 2 ? arg.b1_g_o_n_strides[2] : arg.b1_g_o_n_strides[1];
    const auto e1_stride_lowest = arg.e1_g_m_o_strides[2];
    // NOTE: We don't check D0s/D1s stride, as they are already forced to be row-major
    // and the lowest dimension stride is hardcoded to 1
    // NOTE(review): this condition passes as long as ANY one of the four
    // strides is 1. If the intent is that EACH vectorized tensor be
    // contiguous along its vector dimension, the per-tensor checks should
    // be combined with && (possibly skipping tensors whose ScalarPerVector
    // is 1) — TODO confirm against the gridwise kernel's access pattern.
    if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
         e1_stride_lowest == 1))
    {
        print("DeviceOp: Data Vectorize transfer err\n");
        return false;
    }
    // Without MNKO padding, K must be evenly divisible by both K-pack factors.
    if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MNKOPadding))
    {
        return false;
    }
    return true;
}
// polymorphic
// Type-erased entry point: downcast the base argument to this operator's
// RawArg and forward to the static overload above.
// NOTE(review): removed an unreachable duplicate return (merge leftover).
// The dynamic_cast is unchecked; passing a mismatched argument type would
// dereference a null pointer — callers must supply this op's RawArg.
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
    return IsSupportedArgument(*dynamic_cast<const RawArg*>(p_arg));
}
// Invoker: launches the fused batched-GEMM+GEMM kernel for a prepared
// argument and returns the time measured by launch_and_time_kernel.
struct Invoker : public BaseInvoker
{
    using Argument = DeviceOp::RawArg;
    // Run the kernel for `arg` on the stream described by `stream_config`.
    // Grid size = batch_count * ceil(M / MPerBlock) * ceil(O / NPerBlock),
    // flattened into a 1-D launch.
    float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
    {
        const auto M0 = math::integer_divide_ceil(arg.M, MPerBlock);
        const auto N0 = math::integer_divide_ceil(arg.O, NPerBlock);
        const index_t grid_size = arg.batch_count * M0 * N0;
        // Instantiate and launch the kernel template for one specific
        // (has-main-K-loop, tail-number) combination; both must be
        // compile-time constants, hence the integral_constant parameters.
        auto launch_kernel = [&](auto has_main_k_block_loop, auto tail_number) {
            constexpr bool has_loop = decltype(has_main_k_block_loop)::value;
            constexpr TailNumber tn = tail_number;
            const auto kernel =
                kernel_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3<DeviceOp,
                                                                                GridwiseOp,
                                                                                has_loop,
                                                                                tn>;
            return launch_and_time_kernel(
                stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, arg);
        };
        // Runtime properties of the K loop select which compile-time
        // specialization is valid for the configured pipeline version.
        bool HasMainKBlockLoop = GridwiseOp::CalculateHasMainKBlockLoop(arg.K);
        TailNumber TailNum = GridwiseOp::CalculateKBlockLoopTailNum(arg.K);
        if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
        {
            // Pipeline v1 only supports TailNumber::Full (with or without
            // a main K loop).
            if(HasMainKBlockLoop && TailNum == TailNumber::Full)
            {
                return launch_kernel(std::integral_constant<bool, true>{},
                                     std::integral_constant<TailNumber, TailNumber::Full>{});
            }
            else if(!HasMainKBlockLoop && TailNum == TailNumber::Full)
            {
                return launch_kernel(std::integral_constant<bool, false>{},
                                     std::integral_constant<TailNumber, TailNumber::Full>{});
            }
            else
            {
                // Unsupported combination: report and return a zero time.
                printf("Invalid HasMainKBlockLoop and TailNum combination for V1!\n");
                return 0.0f;
            }
        }
        else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
        {
            // Pipeline v3 accepts Full with a main K loop, or Even/Odd
            // tails without one.
            if(HasMainKBlockLoop && TailNum == TailNumber::Full)
            {
                return launch_kernel(std::integral_constant<bool, true>{},
                                     std::integral_constant<TailNumber, TailNumber::Full>{});
            }
            else if(!HasMainKBlockLoop && TailNum == TailNumber::Even)
            {
                return launch_kernel(std::integral_constant<bool, false>{},
                                     std::integral_constant<TailNumber, TailNumber::Even>{});
            }
            else if(!HasMainKBlockLoop && TailNum == TailNumber::Odd)
            {
                return launch_kernel(std::integral_constant<bool, false>{},
                                     std::integral_constant<TailNumber, TailNumber::Odd>{});
            }
            else
            {
                printf("Invalid HasMainKBlockLoop and TailNum combination for V3!\n");
                return 0.0f;
            }
        }
        else
        {
            printf("Invalid pipeline version!\n");
            return 0.0f;
        }
    }
    // polymorphic
    // Type-erased entry point; expects `p_arg` to be this op's RawArg.
    float Run(const BaseArgument* p_arg,
              const StreamConfig& stream_config = StreamConfig{}) override
    {
        return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
    }
};
static auto MakeArgument(const ADataType* p_a0,
const B0DataType* p_b0,
std::array<const void*, NumD0Tensor> p_d0s,
@@ -855,20 +321,20 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3
B1ElementwiseOperation b1_element_op,
CDE1ElementwiseOperation cde1_element_op)
{
return RawArg{p_a0, p_b0,
p_d0s, p_b1,
p_d1s, p_e1,
MRaw, NRaw,
KRaw, Gemm1NRaw,
Batch, StrideA0,
StrideB0, StrideD0s,
StrideB1, StrideD1s,
StrideE1, BatchStrideA0,
BatchStrideB0, BatchStrideD0s,
BatchStrideB1, BatchStrideD1s,
BatchStrideE1, a0_element_op,
b0_element_op, cde0_element_op,
b1_element_op, cde1_element_op};
return Argument{p_a0, p_b0,
p_d0s, p_b1,
p_d1s, p_e1,
MRaw, NRaw,
KRaw, Gemm1NRaw,
Batch, StrideA0,
StrideB0, StrideD0s,
StrideB1, StrideD1s,
StrideE1, BatchStrideA0,
BatchStrideB0, BatchStrideD0s,
BatchStrideB1, BatchStrideD1s,
BatchStrideE1, a0_element_op,
b0_element_op, cde0_element_op,
b1_element_op, cde1_element_op};
}
// polymorphic
@@ -902,34 +368,34 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3
B1ElementwiseOperation b1_element_op,
CDE1ElementwiseOperation c_element_op) override
{
return std::make_unique<RawArg>(static_cast<const ADataType*>(p_a),
static_cast<const B0DataType*>(p_b0),
p_d0s,
static_cast<const B1DataType*>(p_b1),
p_d1s,
static_cast<E1DataType*>(p_c),
M,
N,
K,
O,
Batch,
StrideA,
StrideB0,
StrideD0s,
StrideB1,
StrideD1s,
StrideE1,
BatchStrideA,
BatchStrideB0,
BatchStrideD0s,
BatchStrideB1,
BatchStrideD1s,
BatchStrideE1,
a_element_op,
b0_element_op,
acc_element_op,
b1_element_op,
c_element_op);
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const B0DataType*>(p_b0),
p_d0s,
static_cast<const B1DataType*>(p_b1),
p_d1s,
static_cast<E1DataType*>(p_c),
M,
N,
K,
O,
Batch,
StrideA,
StrideB0,
StrideD0s,
StrideB1,
StrideD1s,
StrideE1,
BatchStrideA,
BatchStrideB0,
BatchStrideD0s,
BatchStrideB1,
BatchStrideD1s,
BatchStrideE1,
a_element_op,
b0_element_op,
acc_element_op,
b1_element_op,
c_element_op);
}
// Factory: create a (stateless) Invoker for this operator.
static auto MakeInvoker() { return Invoker{}; }

View File

@@ -227,7 +227,7 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
sequence<1>>{});
else
return make_static_tile_distribution(
tile_distribution_encoding< //
tile_distribution_encoding<
sequence<NWarps>,
tuple<sequence<MWarps, MXdlPack, MPerXdl>,
sequence<K_Thread / AK1, K_Lane, AK1 / APackedSize>>,

View File

@@ -392,8 +392,4 @@ struct BlockReduce2D
InDataType reduce_init;
};
// deduction guide
template <typename T>
CK_TILE_HOST_DEVICE_EXTERN BlockReduce2D(const T&, const typename T::DataType&) -> BlockReduce2D<T>;
} // namespace ck_tile

View File

@@ -40,7 +40,7 @@ struct BlockSoftmax2D
#endif
// compute row max
auto reduce_row_max = BlockReduce2D{x, -numeric<DataType>::infinity()};
auto reduce_row_max = BlockReduce2D<decltype(x)>{x, -numeric<DataType>::infinity()};
#if _BLOCK_SOFTMAX_USE_UNPACK2
auto row_max = reduce_row_max(f_max3, f_max, sequence<1, 2>{});
#else

143
script/tools/ck-build Executable file
View File

@@ -0,0 +1,143 @@
#!/bin/bash
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# CK Build - Build Composable Kernel targets in Docker
#
# Ensures the dev container is running, configures CMake on first use
# (or on --reconfigure), optionally cleans, then builds ninja targets.
set -e
set -o pipefail
# Find script directory and load common utilities
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
source "${SCRIPT_DIR}/common.sh"
# Initialize configuration
PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}")
CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}")
# Help message
show_help() {
    cat << EOF
CK Build - Build Composable Kernel targets in Docker
Usage: ck-build [options] [target...]
Options:
  -h, --help     Show this help message
  --name <name>  Specify container name
  --reconfigure  Reconfigure CMake before building
  -j <N>         Parallel jobs (passed to ninja)
  --clean        Clean before building
Arguments:
  target         Target(s) to build (default: all)
Environment:
  CK_CONTAINER_NAME - Override default container name
  GPU_TARGET        - Override GPU target detection (e.g., gfx950, gfx942)
Examples:
  ck-build                            # Build all targets
  ck-build test_amdgcn_mma            # Build specific target
  ck-build test_amdgcn_mma test_gemm  # Build multiple targets
  ck-build --reconfigure              # Reconfigure CMake and build all
  ck-build --clean test_amdgcn_mma    # Clean and build target
  ck-build -j 8 test_amdgcn_mma       # Build with 8 parallel jobs
EOF
}
# Parse arguments
targets=()
reconfigure=false
clean=false
parallel_jobs=""
while [[ $# -gt 0 ]]; do
    case $1 in
        -h|--help)
            show_help
            exit 0
            ;;
        --name)
            CONTAINER_NAME="$2"
            shift 2
            ;;
        --reconfigure)
            reconfigure=true
            shift
            ;;
        --clean)
            clean=true
            shift
            ;;
        -j)
            parallel_jobs="-j $2"
            shift 2
            ;;
        *)
            # Anything that is not a recognized option is a build target.
            targets+=("$1")
            shift
            ;;
    esac
done
# Ensure container is running
if ! container_is_running "${CONTAINER_NAME}"; then
    echo "Container '${CONTAINER_NAME}' not running. Starting..."
    "${SCRIPT_DIR}/ck-start" "${CONTAINER_NAME}"
    echo ""
fi
# Configure CMake if needed or requested
if [ "$reconfigure" = true ] || ! docker exec "${CONTAINER_NAME}" test -f /workspace/build/build.ninja 2>/dev/null; then
    echo "Detecting GPU target..."
    GPU_TARGET_DETECTED=$(detect_gpu_target "${CONTAINER_NAME}")
    if [ "$reconfigure" = true ]; then
        echo "Reconfiguring CMake from scratch for GPU target: ${GPU_TARGET_DETECTED}"
    else
        echo "Configuring build with CMake for GPU target: ${GPU_TARGET_DETECTED}"
    fi
    # Fix: enable errexit/pipefail inside the container shell, otherwise a
    # CMake failure is masked by the successful `tail -30` at the end of
    # the pipe and the script continues as if configuration succeeded.
    docker exec "${CONTAINER_NAME}" bash -c "
        set -e
        set -o pipefail
        cd /workspace || exit 1
        rm -rf /workspace/build
        mkdir /workspace/build
        cd /workspace/build || exit 1
        cmake .. -GNinja \
            -DGPU_TARGETS=${GPU_TARGET_DETECTED} \
            -DCMAKE_BUILD_TYPE=Release \
            -DCMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \
            -DBUILD_TESTING=ON 2>&1 | tail -30
    "
    # Belt and braces: verify configuration actually produced a ninja file.
    if ! docker exec "${CONTAINER_NAME}" test -f /workspace/build/build.ninja 2>/dev/null; then
        echo "Error: CMake configuration failed (build.ninja not generated)" >&2
        exit 1
    fi
    echo ""
fi
# Clean if requested
if [ "$clean" = true ]; then
    echo "Cleaning build directory..."
    docker exec "${CONTAINER_NAME}" bash -c "
        cd /workspace/build || exit 1
        ninja clean
    "
    echo ""
fi
# Build targets (all configured targets when none were given).
if [ ${#targets[@]} -eq 0 ]; then
    echo "Building all configured targets..."
    docker exec "${CONTAINER_NAME}" bash -c "
        cd /workspace/build || exit 1
        ninja ${parallel_jobs} 2>&1
    "
else
    echo "Building targets: ${targets[*]}"
    docker exec "${CONTAINER_NAME}" bash -c "
        cd /workspace/build || exit 1
        ninja ${parallel_jobs} ${targets[*]} 2>&1
    "
fi
echo ""
echo "Build complete ✓"

113
script/tools/ck-clean Executable file
View File

@@ -0,0 +1,113 @@
#!/bin/bash
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# CK Clean - Clean build artifacts in Docker container
set -e
set -o pipefail
# Locate this script and pull in the shared helper functions.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
source "${SCRIPT_DIR}/common.sh"
# Resolve project root and the default per-user/per-branch container name.
PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}")
CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}")
# Usage text.
show_help() {
    cat << EOF
CK Clean - Clean build artifacts in Docker container
Usage: ck-clean [options]
Options:
  -h, --help     Show this help message
  --name <name>  Specify container name
  --all          Remove entire build directory
  -f, --force    Force without confirmation
Environment:
  CK_CONTAINER_NAME - Override default container name
Examples:
  ck-clean                # Clean build artifacts (ninja clean)
  ck-clean --all          # Remove entire build directory
  ck-clean --force --all  # Remove build directory without confirmation
EOF
}
# Option parsing: any unrecognized token is an error.
wipe_build=false
skip_confirm=false
while [[ $# -gt 0 ]]; do
    case $1 in
        -h|--help)
            show_help
            exit 0
            ;;
        --name)
            CONTAINER_NAME="$2"
            shift 2
            ;;
        --all)
            wipe_build=true
            shift
            ;;
        -f|--force)
            skip_confirm=true
            shift
            ;;
        *)
            echo "Unknown option: $1"
            show_help
            exit 1
            ;;
    esac
done
# The container must already be up; cleaning never auto-starts it.
if ! container_is_running "${CONTAINER_NAME}"; then
    echo "Container '${CONTAINER_NAME}' not running"
    echo "Start with: ck-start"
    exit 1
fi
# Nothing to do when there is no build directory at all.
if ! docker exec "${CONTAINER_NAME}" test -d /workspace/build 2>/dev/null; then
    echo "Build directory does not exist"
    exit 0
fi
if [ "$wipe_build" = true ]; then
    # Full wipe: ask first unless --force was given.
    if [ "$skip_confirm" = false ]; then
        read -p "Remove entire build directory? (y/N) " -n 1 -r
        echo ""
        if [[ ! $REPLY =~ ^[Yy]$ ]]; then
            echo "Cancelled"
            exit 0
        fi
    fi
    echo "Removing build directory..."
    docker exec "${CONTAINER_NAME}" bash -c "rm -rf /workspace/build"
    echo "Build directory removed ✓"
else
    # Incremental clean requires a configured ninja build tree.
    if ! docker exec "${CONTAINER_NAME}" test -f /workspace/build/build.ninja 2>/dev/null; then
        echo "Build not configured (build.ninja not found)"
        echo "Use --all to remove build directory"
        exit 1
    fi
    echo "Cleaning build artifacts..."
    docker exec "${CONTAINER_NAME}" bash -c "
        cd /workspace/build || exit 1
        ninja clean
    "
    echo "Build artifacts cleaned ✓"
fi

111
script/tools/ck-exec Executable file
View File

@@ -0,0 +1,111 @@
#!/bin/bash
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# CK Exec - Execute arbitrary commands in Docker container
#
# Wrapper options must come BEFORE the command; everything from the first
# non-option word onward is forwarded verbatim to the command.
set -e
set -o pipefail
# Find script directory and load common utilities
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
source "${SCRIPT_DIR}/common.sh"
# Initialize configuration
PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}")
CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}")
# Help message
show_help() {
    cat << EOF
CK Exec - Execute arbitrary commands in Docker container
Usage: ck-exec [options] <command> [args...]
Options:
  -h, --help         Show this help message
  --name <name>      Specify container name
  -w <dir>           Working directory (default: /workspace)
  -i, --interactive  Interactive mode (allocate TTY)
Arguments:
  command            Command to execute (required)
  args               Arguments to the command
Environment:
  CK_CONTAINER_NAME - Override default container name
Examples:
  ck-exec rocm-smi                               # Run rocm-smi
  ck-exec rocminfo                               # Run rocminfo
  ck-exec ls -la build/bin                       # List build binaries
  ck-exec -w /workspace/build ninja -t commands  # Run ninja commands
  ck-exec --interactive python3                  # Interactive Python session
Common Commands:
  ck-exec rocm-smi                   # Check GPU status
  ck-exec rocminfo \| grep gfx       # Check GPU architecture
  ck-exec hipcc --version            # Check HIP compiler version
  ck-exec cmake --version            # Check CMake version
  ck-exec ninja -C build -t targets  # List all build targets
EOF
}
# Parse wrapper arguments. Options are recognized only before the command.
workdir="/workspace"
interactive=false
command_args=()
while [[ $# -gt 0 ]]; do
    case $1 in
        -h|--help)
            show_help
            exit 0
            ;;
        --name)
            CONTAINER_NAME="$2"
            shift 2
            ;;
        -w)
            workdir="$2"
            shift 2
            ;;
        -i|--interactive)
            interactive=true
            shift
            ;;
        *)
            # Fix: stop option parsing at the first positional word so the
            # wrapped command's own flags are not consumed by this wrapper
            # (e.g. `ck-exec python3 --help` now shows python3's help, not
            # ck-exec's). This matches the documented usage
            # `ck-exec [options] <command> [args...]`.
            command_args=("$@")
            break
            ;;
    esac
done
# Validate command
if [ ${#command_args[@]} -eq 0 ]; then
    echo "Error: command required"
    echo ""
    show_help
    exit 1
fi
# Ensure container is running
if ! container_is_running "${CONTAINER_NAME}"; then
    echo "Container '${CONTAINER_NAME}' not running. Starting..."
    "${SCRIPT_DIR}/ck-start" "${CONTAINER_NAME}"
    echo ""
fi
# Re-quote each word so spaces and shell metacharacters survive the
# inner `bash -c` invocation unchanged.
cmd_string=""
for arg in "${command_args[@]}"; do
    cmd_string="${cmd_string} $(printf '%q' "$arg")"
done
# Execute command; allocate a TTY only when --interactive was requested.
if [ "$interactive" = true ]; then
    docker exec -it -w "${workdir}" "${CONTAINER_NAME}" bash -c "${cmd_string}"
else
    docker exec -w "${workdir}" "${CONTAINER_NAME}" bash -c "${cmd_string}"
fi

134
script/tools/ck-logs Executable file
View File

@@ -0,0 +1,134 @@
#!/bin/bash
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# CK Logs - View container logs and build output
#
# Three modes, checked in this order:
#   --cmake : print key CMakeCache.txt settings, then exit
#   --build : print the tail of ninja's log file, then exit
#   default : print (or follow) the container's docker logs
# NOTE(review): if both --cmake and --build are given, --cmake wins
# because its branch exits first.
set -e
set -o pipefail
# Find script directory and load common utilities
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
source "${SCRIPT_DIR}/common.sh"
# Initialize configuration
PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}")
CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}")
# Help message
show_help() {
    cat << EOF
CK Logs - View container logs and build output
Usage: ck-logs [options] [container_name]
Options:
  -h, --help     Show this help message
  --name <name>  Specify container name
  -f, --follow   Follow log output
  -n, --tail <N> Show last N lines (default: 100)
  --cmake        Show CMake configuration log
  --build        Show last build log
Arguments:
  container_name Optional container name (default: ck_<username>_<branch>)
Environment:
  CK_CONTAINER_NAME - Override default container name
Examples:
  ck-logs          # Show last 100 lines of container logs
  ck-logs -f       # Follow container logs
  ck-logs -n 500   # Show last 500 lines
  ck-logs --cmake  # Show CMake configuration
  ck-logs --build  # Show build log
EOF
}
# Parse arguments
follow=false
tail_lines=100
show_cmake=false
show_build=false
while [[ $# -gt 0 ]]; do
    case $1 in
        -h|--help)
            show_help
            exit 0
            ;;
        --name)
            CONTAINER_NAME="$2"
            shift 2
            ;;
        -f|--follow)
            follow=true
            shift
            ;;
        -n|--tail)
            tail_lines="$2"
            shift 2
            ;;
        --cmake)
            show_cmake=true
            shift
            ;;
        --build)
            show_build=true
            shift
            ;;
        *)
            # NOTE(review): any unrecognized token — including a mistyped
            # flag — is silently taken as the container name.
            CONTAINER_NAME="$1"
            shift
            ;;
    esac
done
# Check if container exists
if ! container_exists "${CONTAINER_NAME}"; then
    echo "Container '${CONTAINER_NAME}' does not exist"
    exit 1
fi
# Show CMake log: extract the interesting cache variables from
# CMakeCache.txt inside the container. The `\$(...)` escapes defer
# command substitution to the container's shell, not this one.
if [ "$show_cmake" = true ]; then
    echo "CMake Configuration Log:"
    echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
    if docker exec "${CONTAINER_NAME}" test -f /workspace/build/CMakeCache.txt 2>/dev/null; then
        docker exec "${CONTAINER_NAME}" bash -c "
            cd /workspace/build
            echo 'GPU_TARGETS:' \$(grep 'GPU_TARGETS:' CMakeCache.txt | cut -d'=' -f2)
            echo 'CMAKE_BUILD_TYPE:' \$(grep 'CMAKE_BUILD_TYPE:' CMakeCache.txt | cut -d'=' -f2)
            echo 'CMAKE_CXX_COMPILER:' \$(grep 'CMAKE_CXX_COMPILER:' CMakeCache.txt | cut -d'=' -f2)
            echo 'BUILD_TESTING:' \$(grep 'BUILD_TESTING:' CMakeCache.txt | cut -d'=' -f2)
        "
    else
        echo "CMake not configured yet"
    fi
    exit 0
fi
# Show build log (last build output)
# NOTE(review): .ninja_log is ninja's internal bookkeeping file (build
# timings), not captured compiler output — confirm this is the intended
# "build log" source.
if [ "$show_build" = true ]; then
    echo "Last Build Log:"
    echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
    if docker exec "${CONTAINER_NAME}" test -f /workspace/build/.ninja_log 2>/dev/null; then
        docker exec "${CONTAINER_NAME}" bash -c "tail -50 /workspace/build/.ninja_log"
    else
        echo "No build log found"
    fi
    exit 0
fi
# Show container logs (default mode): follow, or tail N lines.
echo "Container Logs (${CONTAINER_NAME}):"
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
if [ "$follow" = true ]; then
    docker logs -f "${CONTAINER_NAME}"
else
    docker logs --tail "${tail_lines}" "${CONTAINER_NAME}"
fi

84
script/tools/ck-shell Executable file
View File

@@ -0,0 +1,84 @@
#!/bin/bash
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# CK Shell - Open interactive shell in Docker container
set -e
set -o pipefail
# Locate this script and load the shared helper functions.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
source "${SCRIPT_DIR}/common.sh"
# Defaults derived from the repository and the current user/branch.
PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}")
CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}")
# Usage text.
show_help() {
    cat << EOF
CK Shell - Open interactive shell in Docker container
Usage: ck-shell [options] [container_name]
Options:
  -h, --help     Show this help message
  --name <name>  Specify container name
  -c <command>   Execute command instead of interactive shell
Arguments:
  container_name Optional container name (default: ck_<username>_<branch>)
Environment:
  CK_CONTAINER_NAME - Override default container name
Examples:
  ck-shell                          # Open interactive shell
  ck-shell my_container             # Open shell in specific container
  ck-shell -c "rocm-smi"            # Execute single command
  ck-shell -c "cd build && ls bin"  # Execute command in build directory
EOF
}
# Argument parsing; a bare word is treated as the container name.
one_shot_cmd=""
while [[ $# -gt 0 ]]; do
    case $1 in
        -h|--help)
            show_help
            exit 0
            ;;
        --name)
            CONTAINER_NAME="$2"
            shift 2
            ;;
        -c)
            one_shot_cmd="$2"
            shift 2
            ;;
        *)
            CONTAINER_NAME="$1"
            shift
            ;;
    esac
done
# Bring the container up first if it is not already running.
if ! container_is_running "${CONTAINER_NAME}"; then
    echo "Container '${CONTAINER_NAME}' not running. Starting..."
    "${SCRIPT_DIR}/ck-start" "${CONTAINER_NAME}"
    echo ""
fi
# Either run a single command non-interactively, or attach a shell.
if [ -n "$one_shot_cmd" ]; then
    echo "Executing: ${one_shot_cmd}"
    echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
    docker exec "${CONTAINER_NAME}" bash -c "${one_shot_cmd}"
else
    echo "Opening shell in '${CONTAINER_NAME}' (type 'exit' to leave)..."
    docker exec -it "${CONTAINER_NAME}" bash
fi

103
script/tools/ck-start Executable file
View File

@@ -0,0 +1,103 @@
#!/bin/bash
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# CK Start - Start Docker container for Composable Kernel testing
set -e
set -o pipefail
# Resolve our own location and pull in the shared helpers.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
source "${SCRIPT_DIR}/common.sh"
# Derive defaults: project root and per-user/per-branch container name.
PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}")
CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}")
# Usage text.
show_help() {
    cat << EOF
CK Start - Start Docker container for Composable Kernel testing
Usage: ck-start [options] [container_name]
Options:
  -h, --help       Show this help message
  --image <image>  Specify Docker image (overrides CK_DOCKER_IMAGE)
Arguments:
  container_name   Optional container name (default: ck_<username>_<branch>)
Environment:
  CK_CONTAINER_NAME - Override default container name
  CK_DOCKER_IMAGE   - Override Docker image (default: rocm/composable_kernel:ck_ub24.04_rocm7.0.1)
Examples:
  ck-start                  # Start container with default name
  ck-start my_ck_container  # Start container with custom name
  ck-start --image rocm/composable_kernel:latest
EOF
}
# Argument parsing: options plus an optional positional container name.
while [[ $# -gt 0 ]]; do
    case $1 in
        -h|--help)
            show_help
            exit 0
            ;;
        --image)
            export CK_DOCKER_IMAGE="$2"
            shift 2
            ;;
        *)
            CONTAINER_NAME="$1"
            shift
            ;;
    esac
done
# Pick the image (environment override or default from common.sh).
DOCKER_IMAGE=$(get_docker_image)
# Reuse an existing container when possible: already running -> just
# report; exists but stopped -> restart it. Both paths exit here.
if container_is_running "${CONTAINER_NAME}"; then
    echo "Container '${CONTAINER_NAME}' is already running"
    docker exec "${CONTAINER_NAME}" bash -c "echo 'Working directory:' && pwd"
    exit 0
elif container_exists "${CONTAINER_NAME}"; then
    echo "Starting existing container '${CONTAINER_NAME}'..."
    docker start "${CONTAINER_NAME}"
    echo "Container started"
    docker exec "${CONTAINER_NAME}" bash -c "echo 'Working directory:' && pwd"
    exit 0
fi
# No container yet: create one with the GPU device nodes exposed and the
# repository bind-mounted at /workspace, kept alive by a no-op process.
echo "Creating new Docker container '${CONTAINER_NAME}'..."
echo "Docker image: ${DOCKER_IMAGE}"
echo "Project root: ${PROJECT_ROOT}"
echo ""
docker run -d \
    --name "${CONTAINER_NAME}" \
    --device=/dev/kfd --device=/dev/dri \
    --security-opt seccomp=unconfined \
    --group-add video \
    -v "${PROJECT_ROOT}":/workspace \
    -w /workspace \
    "${DOCKER_IMAGE}" \
    tail -f /dev/null
echo ""
echo "Container '${CONTAINER_NAME}' started successfully"
docker exec "${CONTAINER_NAME}" bash -c "echo 'Working directory:' && pwd"
# Show GPU info (best effort; degrades gracefully without a GPU).
echo ""
echo "GPU Information:"
docker exec "${CONTAINER_NAME}" bash -c "rocm-smi --showproductname 2>/dev/null | head -5 || echo 'No GPU detected'"

153
script/tools/ck-status Executable file
View File

@@ -0,0 +1,153 @@
#!/bin/bash
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# CK Status - Check container status and information
#
# Reports whether the CK dev container is running/stopped/absent, shows
# GPU info, and (with -v) container details and build-tree status.
set -e
set -o pipefail
# Find script directory and load common utilities
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
source "${SCRIPT_DIR}/common.sh"
# Initialize configuration
PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}")
CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}")
# Help message
show_help() {
    cat << EOF
CK Status - Check container status and information
Usage: ck-status [options] [container_name]
Options:
  -h, --help     Show this help message
  --name <name>  Specify container name
  --all          Show all CK containers
  -v, --verbose  Show detailed information
Arguments:
  container_name Optional container name (default: ck_<username>_<branch>)
Environment:
  CK_CONTAINER_NAME - Override default container name
Examples:
  ck-status               # Check default container status
  ck-status my_container  # Check specific container
  ck-status --all         # Show all CK containers
  ck-status -v            # Show detailed information
EOF
}
# Parse arguments
show_all=false
verbose=false
while [[ $# -gt 0 ]]; do
    case $1 in
        -h|--help)
            show_help
            exit 0
            ;;
        --name)
            CONTAINER_NAME="$2"
            shift 2
            ;;
        --all)
            show_all=true
            shift
            ;;
        -v|--verbose)
            verbose=true
            shift
            ;;
        *)
            # Any bare word is treated as the container name.
            CONTAINER_NAME="$1"
            shift
            ;;
    esac
done
DOCKER_IMAGE=$(get_docker_image)
# Show all containers for this user, then exit.
# NOTE(review): `--filter "name=..."` matches substrings/regexes, so this
# may also list containers whose names merely CONTAIN "ck_<user>_" —
# confirm that is acceptable.
if [ "$show_all" = true ]; then
    echo "Composable Kernel Docker Containers:"
    echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
    username=$(get_username)
    containers=$(docker ps -a --filter "name=ck_${username}_" --format "table {{.Names}}\t{{.Status}}\t{{.CreatedAt}}" 2>/dev/null || echo "")
    if [ -z "$containers" ] || [ "$containers" = "NAMES STATUS CREATED AT" ]; then
        echo "No CK containers found for user '${username}'"
    else
        echo "$containers"
    fi
    exit 0
fi
# Check specific container status
echo "Container: ${CONTAINER_NAME}"
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
if container_is_running "${CONTAINER_NAME}"; then
    echo "Status: RUNNING ✓"
    echo ""
    # Anchored regex so only the exact container name matches here.
    docker ps --filter "name=^${CONTAINER_NAME}$" --format "table {{.Names}}\t{{.Status}}\t{{.Image}}"
    if [ "$verbose" = true ]; then
        echo ""
        echo "Container Details:"
        echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
        docker inspect "${CONTAINER_NAME}" --format '
Image: {{.Config.Image}}
Created: {{.Created}}
Platform: {{.Platform}}
Mounts: {{range .Mounts}}
- {{.Source}} -> {{.Destination}}{{end}}
'
    fi
    echo ""
    echo "GPU Information:"
    echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
    docker exec "${CONTAINER_NAME}" bash -c "rocm-smi --showproductname 2>/dev/null | head -10 || echo 'No GPU detected'"
    if [ "$verbose" = true ]; then
        echo ""
        echo "GPU Target:"
        gpu_target=$(detect_gpu_target "${CONTAINER_NAME}")
        echo "  ${gpu_target}"
        echo ""
        echo "Build Status:"
        echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
        if docker exec "${CONTAINER_NAME}" test -d /workspace/build 2>/dev/null; then
            if docker exec "${CONTAINER_NAME}" test -f /workspace/build/build.ninja 2>/dev/null; then
                echo "  CMake configured ✓"
                echo "  Build directory: /workspace/build"
                # Count built test binaries
                bin_count=$(docker exec "${CONTAINER_NAME}" bash -c "ls -1 /workspace/build/bin 2>/dev/null | wc -l" || echo "0")
                echo "  Test binaries: ${bin_count}"
            else
                echo "  CMake not configured"
            fi
        else
            echo "  Build directory not found"
        fi
    fi
elif container_exists "${CONTAINER_NAME}"; then
    echo "Status: STOPPED"
    echo ""
    echo "Start with: ck-start"
else
    echo "Status: DOES NOT EXIST"
    echo ""
    echo "Create with: ck-start"
fi

141
script/tools/ck-stop Executable file
View File

@@ -0,0 +1,141 @@
#!/bin/bash
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# CK Stop - Stop and remove Docker container
set -e
set -o pipefail
# Find script directory and load common utilities
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
source "${SCRIPT_DIR}/common.sh"
# Initialize configuration
PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}")
CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}")
# Help message
show_help() {
    cat << EOF
CK Stop - Stop and remove Docker container
Usage: ck-stop [options] [container_name]
Options:
  -h, --help   Show this help message
  -f, --force  Force stop without confirmation
  --all        Stop all CK containers for this user
Arguments:
  container_name Optional container name (default: ck_<username>_<branch>)
Environment:
  CK_CONTAINER_NAME - Override default container name
Examples:
  ck-stop                  # Stop default container
  ck-stop my_ck_container  # Stop specific container
  ck-stop --all            # Stop all user's CK containers
  ck-stop --force          # Stop without confirmation
EOF
}
# Parse arguments
force=false
stop_all=false
while [[ $# -gt 0 ]]; do
    case $1 in
        -h|--help)
            show_help
            exit 0
            ;;
        -f|--force)
            force=true
            shift
            ;;
        --all)
            stop_all=true
            shift
            ;;
        *)
            CONTAINER_NAME="$1"
            shift
            ;;
    esac
done
# Stop and remove one container by name. Returns 1 (without touching
# docker) when the container does not exist.
stop_container() {
    local name="$1"
    if ! container_exists "${name}"; then
        echo "Container '${name}' does not exist"
        return 1
    fi
    echo "Stopping and removing container '${name}'..."
    docker stop "${name}" 2>/dev/null || true
    docker rm "${name}" 2>/dev/null || true
    echo "Container '${name}' stopped and removed"
}
# Stop all of this user's CK containers (after one confirmation).
if [ "$stop_all" = true ]; then
    username=$(get_username)
    containers=$(docker ps -a --filter "name=ck_${username}_" --format '{{.Names}}')
    if [ -z "$containers" ]; then
        echo "No CK containers found for user '${username}'"
        exit 0
    fi
    echo "Found CK containers for user '${username}':"
    echo "$containers"
    echo ""
    if [ "$force" = false ]; then
        read -p "Stop and remove all these containers? (y/N) " -n 1 -r
        echo ""
        if [[ ! $REPLY =~ ^[Yy]$ ]]; then
            echo "Cancelled"
            exit 0
        fi
    fi
    echo ""
    while IFS= read -r container; do
        # Fix: `|| true` — if a container disappears between listing and
        # stopping (or is removed concurrently), stop_container returns 1,
        # and without the guard `set -e` would abort the loop mid-cleanup.
        stop_container "$container" || true
    done <<< "$containers"
    echo ""
    echo "All containers stopped and removed"
    exit 0
fi
# Stop single container
if ! container_exists "${CONTAINER_NAME}"; then
    echo "Container '${CONTAINER_NAME}' does not exist"
    exit 0
fi
# Show container info
if container_is_running "${CONTAINER_NAME}"; then
    echo "Container '${CONTAINER_NAME}' is currently running"
else
    echo "Container '${CONTAINER_NAME}' exists but is stopped"
fi
# Confirm if not forced
if [ "$force" = false ]; then
    read -p "Stop and remove container '${CONTAINER_NAME}'? (y/N) " -n 1 -r
    echo ""
    if [[ ! $REPLY =~ ^[Yy]$ ]]; then
        echo "Cancelled"
        exit 0
    fi
fi
stop_container "${CONTAINER_NAME}"

166
script/tools/ck-test Executable file
View File

@@ -0,0 +1,166 @@
#!/bin/bash
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# CK Test - Build and test Composable Kernel in Docker
# Abort on the first failing command, including failures inside pipelines.
set -e
set -o pipefail
# Find script directory and load common utilities
# (common.sh provides get_project_root, get_container_name,
# container_is_running, detect_gpu_target — see sibling scripts)
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
source "${SCRIPT_DIR}/common.sh"
# Initialize configuration
# CONTAINER_NAME may later be replaced by the --name option during
# argument parsing below.
PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}")
CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}")
# Help message
# show_help: print the ck-test usage text to stdout.
# The heredoc delimiter is unquoted, but the body contains no shell
# expansions, so the text is emitted verbatim.
show_help() {
cat << EOF
CK Test - Build and test Composable Kernel in Docker
Usage: ck-test [options] <test_name> [test_options]
Options:
-h, --help Show this help message
--name <name> Specify container name
--reconfigure Reconfigure CMake before building
--no-build Skip building, run test directly
Arguments:
test_name Name of test executable (required)
test_options Additional options passed to test (e.g., --gtest_filter=*)
Environment:
CK_CONTAINER_NAME - Override default container name
GPU_TARGET - Override GPU target detection (e.g., gfx950, gfx942)
Examples:
ck-test test_amdgcn_mma
ck-test test_amdgcn_mma --gtest_filter=*Fp16*
ck-test --name my_container test_amdgcn_mma
ck-test --reconfigure test_amdgcn_mma
EOF
}
# Parse arguments.
# The first bare token becomes the test name; every later bare token and any
# --gtest_* flag is collected into test_options and forwarded to the test
# executable.
test_name=""
reconfigure=false
no_build=false
test_options=()
while [[ $# -gt 0 ]]; do
    case $1 in
        -h|--help)
            show_help
            exit 0
            ;;
        --name)
            # Override the container name; consumes the following token.
            CONTAINER_NAME="$2"
            shift 2
            ;;
        --reconfigure)
            reconfigure=true
            shift
            ;;
        --no-build)
            no_build=true
            shift
            ;;
        --gtest_*)
            # gtest flags are forwarded even when they appear before the
            # test name.  (A "--help" alternative previously listed in this
            # pattern was unreachable dead code: the -h|--help arm above
            # always matches first.)
            test_options+=("$1")
            shift
            ;;
        *)
            if [ -z "$test_name" ]; then
                test_name="$1"
            else
                test_options+=("$1")
            fi
            shift
            ;;
    esac
done
# Validate test name: it is the one required positional argument.
if [ -z "$test_name" ]; then
    echo "Error: test_name required"
    echo ""
    show_help
    exit 1
fi
# Make sure the target container is up, launching it via ck-start when not.
container_is_running "${CONTAINER_NAME}" || {
    echo "Container '${CONTAINER_NAME}' not running. Starting..."
    "${SCRIPT_DIR}/ck-start" "${CONTAINER_NAME}"
    echo ""
}
# (Re)configure CMake when explicitly requested, or when no Ninja build
# files exist yet inside the container.
need_configure=false
if [ "$reconfigure" = true ]; then
    need_configure=true
elif ! docker exec "${CONTAINER_NAME}" test -f /workspace/build/build.ninja 2>/dev/null; then
    need_configure=true
fi
if [ "$need_configure" = true ]; then
    echo "Detecting GPU target..."
    GPU_TARGET_DETECTED=$(detect_gpu_target "${CONTAINER_NAME}")
    if [ "$reconfigure" = true ]; then
        echo "Reconfiguring CMake from scratch for GPU target: ${GPU_TARGET_DETECTED}"
    else
        echo "Configuring build with CMake for GPU target: ${GPU_TARGET_DETECTED}"
    fi
    # The build tree is wiped and recreated so stale cache entries never
    # survive; only the last 30 lines of CMake output are shown.
    docker exec "${CONTAINER_NAME}" bash -c "
cd /workspace || exit 1
rm -rf /workspace/build
mkdir /workspace/build
cd /workspace/build || exit 1
cmake .. -GNinja \
-DGPU_TARGETS=${GPU_TARGET_DETECTED} \
-DCMAKE_BUILD_TYPE=Release \
-DCMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \
-DBUILD_TESTING=ON 2>&1 | tail -30
"
    echo ""
fi
# Build the test target unless --no-build was given.
# The original if/else duplicated an identical ninja invocation in both
# branches, differing only in the progress message; the build is now issued
# once (ninja is incremental, so an up-to-date binary is effectively a no-op).
if [ "$no_build" = false ]; then
    if docker exec "${CONTAINER_NAME}" test -f "/workspace/build/bin/${test_name}" 2>/dev/null; then
        echo "Test executable found, rebuilding to ensure latest version..."
    else
        echo "Building ${test_name}..."
    fi
    docker exec "${CONTAINER_NAME}" bash -c "
cd /workspace/build || exit 1
ninja ${test_name} 2>&1
"
    echo ""
fi
# Run the test inside the container and report its exit status.
echo "Running: ${test_name} ${test_options[*]}"
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
# Build the remote command with each option shell-quoted (printf %q) so
# values such as --gtest_filter=*Fp16* survive the docker exec bash -c layer.
cmd="cd /workspace/build && ./bin/${test_name}"
for opt in "${test_options[@]}"; do
    cmd="${cmd} $(printf '%q' "$opt")"
done
# This script runs under `set -e`, so a plain `docker exec` followed by
# `exit_code=$?` would abort before the status could be inspected and the
# failure message below would never print.  `|| exit_code=$?` both captures
# the status and suppresses errexit for this one command.
exit_code=0
docker exec "${CONTAINER_NAME}" bash -c "${cmd}" || exit_code=$?
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
if [ $exit_code -eq 0 ]; then
    echo "Test completed successfully"
else
    echo "Test failed with exit code: ${exit_code}"
fi
exit $exit_code

View File

@@ -13,13 +13,8 @@ class TestCkTileGemmPipelineCompV3
static constexpr bool check_data_type()
{
using Base = TestCkTileGemmPipeline<T, TestCkTileGemmPipelineCompV3<T>>;
if constexpr(std::is_same_v<typename Base::ADataType, F8> &&
std::is_same_v<typename Base::BDataType, BF8>)
{
return false;
}
else if constexpr(std::is_same_v<typename Base::BLayout, Row> &&
std::is_same_v<typename Base::BDataType, I4>)
if constexpr(std::is_same_v<typename Base::BLayout, Row> &&
std::is_same_v<typename Base::BDataType, I4>)
{
return false;
}

View File

@@ -170,7 +170,7 @@ using KernelTypesCompV4 = ::testing::Types<
std::tuple< Row, Row, Row, BF16, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Row, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Row, Row, Row, F8, F8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Row, Row, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Row, Row, Row, F8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Row, Row, Row, F8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Row, Row, Row, BF8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Row, Row, Row, BF8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
@@ -180,7 +180,7 @@ using KernelTypesCompV4 = ::testing::Types<
std::tuple< Row, Col, Row, BF16, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Row, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Row, Col, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Row, Col, Row, F8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Row, Col, Row, F8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Row, Col, Row, BF8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Row, Col, Row, BF8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
@@ -190,7 +190,7 @@ using KernelTypesCompV4 = ::testing::Types<
std::tuple< Col, Row, Row, BF16, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Col, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Col, Row, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Col, Row, Row, F8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Col, Row, Row, F8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Col, Row, Row, BF8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Col, Row, Row, BF8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
@@ -200,7 +200,7 @@ using KernelTypesCompV4 = ::testing::Types<
std::tuple< Col, Col, Row, BF16, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Col, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Col, Col, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Col, Col, Row, F8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Col, Col, Row, F8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Col, Col, Row, BF8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Col, Col, Row, BF8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>