mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Merge commit '3900e1e7ceacfa32cb8d1522260ed30befd4dae3' into develop
This commit is contained in:
@@ -25,6 +25,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
|
||||
## Composable Kernel 1.2.0 for ROCm 7.2.0
|
||||
|
||||
### Added
|
||||
* Added tests for f8 x bf8 on CompV3, and f8 x bf8 with K_BlockSize 32 on CompV4
|
||||
* Added CK-Tile dispatcher - a unified kernel dispatch, code generation and architecture-based kernel filtering system with with C++ and Python frontends starting with GEMM support.
|
||||
* Added support for bf16 data type to grouped_gemm and grouped_gemm_preshuffle.
|
||||
* Added Col-Col-Row-Col layout support for aquant mode in blockscale GEMM.
|
||||
|
||||
@@ -3,77 +3,21 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3_common.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/utility/tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename DeviceOp, typename GridwiseOp, bool HasMainKBlockLoop, TailNumber TailNum>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_batched_gemm_gemm_wmma_cshuffle_v3(typename DeviceOp::RawArg arg)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
||||
|
||||
__shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / arg.batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
const long_index_t a_batch_offset =
|
||||
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetABasePtr(g_idx)));
|
||||
const long_index_t b0_batch_offset =
|
||||
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB0BasePtr(g_idx)));
|
||||
const long_index_t b1_batch_offset =
|
||||
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
|
||||
const long_index_t c_batch_offset =
|
||||
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
|
||||
|
||||
GridwiseOp::template Run<HasMainKBlockLoop, TailNum>(
|
||||
arg.p_a_grid + a_batch_offset,
|
||||
arg.p_b0_grid + b0_batch_offset,
|
||||
Tuple<>{}, // p_d0s_grid
|
||||
arg.p_b1_grid + b1_batch_offset,
|
||||
Tuple<>{}, // p_d1s_grid
|
||||
arg.p_c_grid + c_batch_offset,
|
||||
p_shared,
|
||||
arg.a_grid_desc,
|
||||
arg.b0_grid_desc,
|
||||
Tuple<>{}, // D0sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
arg.b1_grid_desc,
|
||||
Tuple<>{}, // D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
arg.a_element_op,
|
||||
arg.b0_element_op,
|
||||
arg.acc_element_op,
|
||||
arg.b1_element_op,
|
||||
arg.c_element_op,
|
||||
arg.block_2_ctile_map);
|
||||
#else
|
||||
ignore = arg;
|
||||
#endif // (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)
|
||||
}
|
||||
|
||||
// Computes C = A * B0 * B1
|
||||
// MN = MK * KL * LN
|
||||
// ^^^^^^ (Acc0)
|
||||
@@ -157,88 +101,47 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
|
||||
// to LPerWmma (A.k.a Gemm0 NPerWmma).
|
||||
static constexpr index_t NPerWmma = LPerWmma;
|
||||
|
||||
// TODO: Now that we are no longer using NumDim or TensorSpec, we can probably use a simpler
|
||||
// Transform operator or just not use one at all.
|
||||
using Transform = TransformBatchedContractionContractionToBatchedGemmGemm_Wmma<
|
||||
Sequence<1, 1, 1, 1, 1>,
|
||||
Sequence<MPerBlock, LPerBlock, KPerBlock, NPerBlock>,
|
||||
GemmSpec,
|
||||
TensorSpecialization::Default, // ASpec
|
||||
TensorSpecialization::Default, // B0Spec
|
||||
TensorSpecialization::Default, // B1Spec
|
||||
TensorSpecialization::Default>; // CSpec
|
||||
|
||||
__host__ __device__ static auto
|
||||
MakeAGridDescriptor(const std::array<index_t, 3>& a_g_m_k_lengths_vec,
|
||||
const std::array<index_t, 3>& a_g_m_k_strides_vec)
|
||||
{
|
||||
return Transform::MakeAGridDescriptor_AK0_M_AK1(
|
||||
Transform::MakeAGridDescriptor_M_K(a_g_m_k_lengths_vec, a_g_m_k_strides_vec),
|
||||
Number<AK1>{});
|
||||
}
|
||||
|
||||
__host__ __device__ static auto
|
||||
MakeB0GridDescriptor(const std::array<index_t, 3>& b0_g_l_k_lengths_vec,
|
||||
const std::array<index_t, 3>& b0_g_l_k_strides_vec)
|
||||
{
|
||||
return Transform::MakeB0GridDescriptor_BK0_N_BK1(
|
||||
Transform::MakeB0GridDescriptor_N_K(b0_g_l_k_lengths_vec, b0_g_l_k_strides_vec),
|
||||
Number<BK1>{});
|
||||
}
|
||||
|
||||
__host__ __device__ static auto
|
||||
MakeB1GridDescriptor(const std::array<index_t, 3>& b1_g_n_l_lengths_vec,
|
||||
const std::array<index_t, 3>& b1_g_n_l_strides_vec)
|
||||
{
|
||||
return Transform::MakeB1GridDescriptor_BK0_N_BK1(
|
||||
Transform::MakeB1GridDescriptor_N_K(b1_g_n_l_lengths_vec, b1_g_n_l_strides_vec),
|
||||
Number<L1>{});
|
||||
}
|
||||
|
||||
using AGridDesc = decltype(MakeAGridDescriptor({}, {}));
|
||||
using B0GridDesc = decltype(MakeB0GridDescriptor({}, {}));
|
||||
using B1GridDesc = decltype(MakeB1GridDescriptor({}, {}));
|
||||
using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
|
||||
|
||||
struct ComputeBasePtrOfStridedBatch
|
||||
{
|
||||
ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
|
||||
index_t BatchStrideB0,
|
||||
index_t BatchStrideB1,
|
||||
index_t BatchStrideC)
|
||||
: BatchStrideA_(BatchStrideA),
|
||||
BatchStrideB0_(BatchStrideB0),
|
||||
BatchStrideB1_(BatchStrideB1),
|
||||
BatchStrideC_(BatchStrideC)
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideA_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideB0_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideB1_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideC_);
|
||||
}
|
||||
|
||||
private:
|
||||
index_t BatchStrideA_;
|
||||
index_t BatchStrideB0_;
|
||||
index_t BatchStrideB1_;
|
||||
index_t BatchStrideC_;
|
||||
};
|
||||
using DeviceGemmGemmCommonBase =
|
||||
DeviceGemmGemm_Wmma_CShuffleV3_Common<DeviceOp,
|
||||
GemmSpec,
|
||||
ALayout,
|
||||
B0layout,
|
||||
Tuple<>, // D0sLayout
|
||||
B1Layout,
|
||||
Tuple<>, // D1sLayout
|
||||
CLayout,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
LPerBlock,
|
||||
KPerBlock,
|
||||
NPerBlock,
|
||||
ADataType,
|
||||
B0DataType,
|
||||
B1DataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
Tuple<>, // D0sDataType
|
||||
Tuple<>, // D1sDataType
|
||||
AElementwiseOperation,
|
||||
B0ElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
B1ElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
AK1,
|
||||
BK1,
|
||||
L1,
|
||||
MPerWmma,
|
||||
LPerWmma,
|
||||
BlkGemmPipelineVer,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
B0BlockTransferSrcVectorDim,
|
||||
B0BlockTransferSrcScalarPerVector,
|
||||
B1BlockTransferSrcVectorDim,
|
||||
B1BlockTransferSrcScalarPerVector,
|
||||
ck::index_t{}, // CDE0BlockTransferSrcScalarPerVector
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
false>; // IsMultiD
|
||||
|
||||
// GridwiseOp
|
||||
using GridwiseOp = GridwiseBatchedGemmGemm_wmma_cshuffle_v3<
|
||||
@@ -260,12 +163,12 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
|
||||
CElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
// InMemory Data Descriptor
|
||||
AGridDesc,
|
||||
B0GridDesc,
|
||||
typename DeviceGemmGemmCommonBase::AGridDesc,
|
||||
typename DeviceGemmGemmCommonBase::B0GridDesc,
|
||||
Tuple<>, // Ds0GridDesc
|
||||
B1GridDesc,
|
||||
typename DeviceGemmGemmCommonBase::B1GridDesc,
|
||||
Tuple<>, // Ds1GridDesc
|
||||
CGridDesc_M_N,
|
||||
typename DeviceGemmGemmCommonBase::CGridDesc_M_N,
|
||||
// Tiling Family
|
||||
MPerBlock,
|
||||
LPerBlock,
|
||||
@@ -312,339 +215,67 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
Transform::matrix_padder.PadN,
|
||||
DeviceGemmGemmCommonBase::GridDescriptorCreator::Transform::matrix_padder.PadN,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer>;
|
||||
|
||||
struct RawArg : public BaseArgument
|
||||
using DeviceGemmGemmCommon = DeviceGemmGemm_Wmma_CShuffleV3_Common_Invoker_Arg<
|
||||
DeviceOp,
|
||||
GemmSpec,
|
||||
ALayout,
|
||||
B0layout,
|
||||
Tuple<>, // D0sLayout
|
||||
B1Layout,
|
||||
Tuple<>, // D1sLayout
|
||||
CLayout,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
LPerBlock,
|
||||
KPerBlock,
|
||||
NPerBlock,
|
||||
ADataType,
|
||||
B0DataType,
|
||||
B1DataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
Tuple<>, // D0sDataType,
|
||||
Tuple<>, // D1sDataType,
|
||||
AElementwiseOperation,
|
||||
B0ElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
B1ElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
AK1,
|
||||
BK1,
|
||||
L1,
|
||||
MPerWmma,
|
||||
LPerWmma,
|
||||
BlkGemmPipelineVer,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
B0BlockTransferSrcVectorDim,
|
||||
B0BlockTransferSrcScalarPerVector,
|
||||
B1BlockTransferSrcVectorDim,
|
||||
B1BlockTransferSrcScalarPerVector,
|
||||
ck::index_t{}, // CDE0BlockTransferSrcScalarPerVector
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
false>; // IsMultiD
|
||||
// Invoker
|
||||
using Invoker = typename DeviceGemmGemmCommon::Invoker;
|
||||
|
||||
// Argument
|
||||
using Argument = typename DeviceGemmGemmCommon::Argument;
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
using arr3 = std::array<ck::index_t, 3>;
|
||||
|
||||
RawArg(const ADataType* p_a_grid_,
|
||||
const B0DataType* p_b0_grid_,
|
||||
const B1DataType* p_b1_grid_,
|
||||
CDataType* p_c_grid_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t O_,
|
||||
index_t Batch,
|
||||
index_t StrideA,
|
||||
index_t StrideB0,
|
||||
index_t StrideB1,
|
||||
index_t StrideC,
|
||||
index_t BatchStrideA,
|
||||
index_t BatchStrideB0,
|
||||
index_t BatchStrideB1,
|
||||
index_t BatchStrideC,
|
||||
AElementwiseOperation a_element_op_,
|
||||
B0ElementwiseOperation b0_element_op_,
|
||||
AccElementwiseOperation acc_element_op_,
|
||||
B1ElementwiseOperation b1_element_op_,
|
||||
CElementwiseOperation c_element_op_)
|
||||
: p_a_grid{p_a_grid_},
|
||||
p_b0_grid{p_b0_grid_},
|
||||
p_b1_grid{p_b1_grid_},
|
||||
p_c_grid{p_c_grid_},
|
||||
M{M_},
|
||||
N{N_},
|
||||
K{K_},
|
||||
O{O_},
|
||||
batch_count{Batch},
|
||||
a_element_op{a_element_op_},
|
||||
b0_element_op{b0_element_op_},
|
||||
acc_element_op{acc_element_op_},
|
||||
b1_element_op{b1_element_op_},
|
||||
c_element_op{c_element_op_},
|
||||
compute_base_ptr_of_batch{BatchStrideA, BatchStrideB0, BatchStrideB1, BatchStrideC}
|
||||
{
|
||||
|
||||
a_g_m_k_lengths = arr3{batch_count, M, K};
|
||||
a_g_m_k_strides = arr3{BatchStrideA, StrideA, 1}; // A layout [batch_count, M, K]
|
||||
|
||||
b0_g_n_k_lengths = arr3{batch_count, N, K};
|
||||
b0_g_n_k_strides = arr3{BatchStrideB0, StrideB0, 1}; // B0 layout [batch_count, N, K]
|
||||
|
||||
b1_g_o_n_lengths = arr3{batch_count, O, N};
|
||||
b1_g_o_n_strides =
|
||||
is_same_v<B1Layout, tensor_layout::gemm::RowMajor>
|
||||
? arr3{BatchStrideB1, 1, StrideB1} // B1 layout [batch_count, N, O]
|
||||
: arr3{BatchStrideB1, StrideB1, 1}; // B1 layout [batch_count, O, N]
|
||||
|
||||
c_g_m_o_lengths = arr3{batch_count, M, O};
|
||||
c_g_m_o_strides = arr3{BatchStrideC, StrideC, 1}; // C layout [batch_count, M, O]
|
||||
|
||||
a_grid_desc = MakeAGridDescriptor(a_g_m_k_lengths, a_g_m_k_strides);
|
||||
b0_grid_desc = MakeB0GridDescriptor(b0_g_n_k_lengths, b0_g_n_k_strides);
|
||||
b1_grid_desc = MakeB1GridDescriptor(b1_g_o_n_lengths, b1_g_o_n_strides);
|
||||
c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N(c_g_m_o_lengths, c_g_m_o_strides);
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
GridwiseOp::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
|
||||
block_2_ctile_map = GridwiseOp::MakeDefaultBlock2ETileMap(c_grid_desc_m_n, 1, 1);
|
||||
}
|
||||
// Pointers
|
||||
const ADataType* p_a_grid;
|
||||
const B0DataType* p_b0_grid;
|
||||
const B1DataType* p_b1_grid;
|
||||
CDataType* p_c_grid;
|
||||
|
||||
// Raw Problem Size
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
index_t O;
|
||||
index_t batch_count;
|
||||
|
||||
arr3 a_g_m_k_lengths;
|
||||
arr3 a_g_m_k_strides;
|
||||
arr3 b0_g_n_k_lengths;
|
||||
arr3 b0_g_n_k_strides;
|
||||
arr3 b1_g_o_n_lengths;
|
||||
arr3 b1_g_o_n_strides;
|
||||
arr3 c_g_m_o_lengths;
|
||||
arr3 c_g_m_o_strides;
|
||||
|
||||
AElementwiseOperation a_element_op;
|
||||
B0ElementwiseOperation b0_element_op;
|
||||
AccElementwiseOperation acc_element_op;
|
||||
B1ElementwiseOperation b1_element_op;
|
||||
CElementwiseOperation c_element_op;
|
||||
|
||||
// Grid descriptors and other mem calculators
|
||||
AGridDesc a_grid_desc;
|
||||
B0GridDesc b0_grid_desc;
|
||||
B1GridDesc b1_grid_desc;
|
||||
CGridDesc_M_N c_grid_desc_m_n;
|
||||
typename GridwiseOp::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
|
||||
typename GridwiseOp::DefaultBlock2ETileMap block_2_ctile_map;
|
||||
|
||||
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch;
|
||||
};
|
||||
|
||||
static bool IsSupportedArgument([[maybe_unused]] const RawArg& arg)
|
||||
{
|
||||
// Print lambda with env check and printf() style formmating.
|
||||
const char* curFunc = __func__;
|
||||
auto print = [&curFunc](const char* format, ...) -> void {
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
#if defined(__clang__)
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wformat-nonliteral"
|
||||
#endif
|
||||
va_list args;
|
||||
va_start(args, format);
|
||||
std::vfprintf(stdout, format, args);
|
||||
va_end(args);
|
||||
#if defined(__clang__)
|
||||
#pragma clang diagnostic pop
|
||||
#endif
|
||||
std::cout << "In file: " << __FILE__ << ", function: " << curFunc << "\n";
|
||||
}
|
||||
};
|
||||
|
||||
if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported()))
|
||||
{
|
||||
print("DeviceOp: Arch err\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<ADataType, f8_t> || std::is_same_v<ADataType, bf8_t> ||
|
||||
std::is_same_v<B0DataType, f8_t> || std::is_same_v<B0DataType, bf8_t> ||
|
||||
std::is_same_v<B1DataType, f8_t> || std::is_same_v<B1DataType, bf8_t>)
|
||||
{
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
print("DeviceOp: gfx 11 does not support fp8\n");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
|
||||
{
|
||||
print("DeviceOp: Acc0 Type err\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(!(is_same_v<ALayout, tensor_layout::gemm::RowMajor>))
|
||||
{
|
||||
print("DeviceOp: A layout must be Row\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(!(is_same_v<B0layout, tensor_layout::gemm::ColumnMajor>))
|
||||
{
|
||||
print("DeviceOp: B layout must be Column\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(!(is_same_v<B1Layout, tensor_layout::gemm::RowMajor> ||
|
||||
is_same_v<B1Layout, tensor_layout::gemm::ColumnMajor>))
|
||||
{
|
||||
print("DeviceOp: B1 layout must be Column or Row\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(!(is_same_v<CLayout, tensor_layout::gemm::RowMajor>))
|
||||
{
|
||||
print("DeviceOp: C layout must be Row\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Other padding modes have not been tested and do not get checked individually.
|
||||
if constexpr(GemmSpec != GemmSpecialization::Default &&
|
||||
GemmSpec != GemmSpecialization::MNKOPadding)
|
||||
{
|
||||
print("Padding mode must be default or MNKO\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Per wmma dimensions not equal to 16 are very untested.
|
||||
if constexpr(MPerWmma != 16 || LPerWmma != 16 || NPerWmma != 16)
|
||||
{
|
||||
print("M, L, N per Wmma must be 16\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!GridwiseOp::CheckValidity(arg.a_grid_desc,
|
||||
arg.b0_grid_desc,
|
||||
Tuple<>{},
|
||||
arg.b1_grid_desc,
|
||||
Tuple<>{},
|
||||
arg.c_grid_desc_m_n,
|
||||
arg.block_2_ctile_map))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check scalar per vector requirement
|
||||
const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? arg.K : arg.M;
|
||||
const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? arg.K : arg.N;
|
||||
const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? arg.N : arg.O;
|
||||
const auto c_extent_lowest = arg.O;
|
||||
|
||||
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
|
||||
b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 &&
|
||||
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
|
||||
c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
{
|
||||
print("DeviceOp: Data Transfer Vector scalar err\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check vector load/store requirement
|
||||
const auto a_stride_lowest =
|
||||
ABlockTransferSrcVectorDim == 2 ? arg.a_g_m_k_strides[2] : arg.a_g_m_k_strides[1];
|
||||
const auto b0_stride_lowest =
|
||||
B0BlockTransferSrcVectorDim == 2 ? arg.b0_g_n_k_strides[2] : arg.b0_g_n_k_strides[1];
|
||||
const auto b1_stride_lowest =
|
||||
B1BlockTransferSrcVectorDim == 2 ? arg.b1_g_o_n_strides[2] : arg.b1_g_o_n_strides[1];
|
||||
const auto c_stride_lowest = arg.c_g_m_o_strides[2];
|
||||
|
||||
if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
|
||||
c_stride_lowest == 1))
|
||||
{
|
||||
print("DeviceOp: Data Vectorize transfer err\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MNKOPadding))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
return DeviceGemmGemmCommon::IsSupportedArgument(arg);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const RawArg*>(p_arg));
|
||||
return DeviceGemmGemmCommon::IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceOp::RawArg;
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
const auto M0 = math::integer_divide_ceil(arg.M, MPerBlock);
|
||||
const auto N0 = math::integer_divide_ceil(arg.O, NPerBlock);
|
||||
|
||||
const index_t grid_size = arg.batch_count * M0 * N0;
|
||||
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop, auto tail_number) {
|
||||
constexpr bool has_loop = decltype(has_main_k_block_loop)::value;
|
||||
constexpr TailNumber tn = tail_number;
|
||||
|
||||
const auto kernel =
|
||||
kernel_batched_gemm_gemm_wmma_cshuffle_v3<DeviceOp, GridwiseOp, has_loop, tn>;
|
||||
|
||||
return launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, arg);
|
||||
};
|
||||
|
||||
bool HasMainKBlockLoop = GridwiseOp::CalculateHasMainKBlockLoop(arg.K);
|
||||
TailNumber TailNum = GridwiseOp::CalculateKBlockLoopTailNum(arg.K);
|
||||
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if(HasMainKBlockLoop && TailNum == TailNumber::Full)
|
||||
{
|
||||
return launch_kernel(std::integral_constant<bool, true>{},
|
||||
std::integral_constant<TailNumber, TailNumber::Full>{});
|
||||
}
|
||||
else if(!HasMainKBlockLoop && TailNum == TailNumber::Full)
|
||||
{
|
||||
return launch_kernel(std::integral_constant<bool, false>{},
|
||||
std::integral_constant<TailNumber, TailNumber::Full>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("Invalid HasMainKBlockLoop and TailNum combination for V1!\n");
|
||||
return 0.0f;
|
||||
}
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
if(HasMainKBlockLoop && TailNum == TailNumber::Full)
|
||||
{
|
||||
return launch_kernel(std::integral_constant<bool, true>{},
|
||||
std::integral_constant<TailNumber, TailNumber::Full>{});
|
||||
}
|
||||
else if(!HasMainKBlockLoop && TailNum == TailNumber::Even)
|
||||
{
|
||||
return launch_kernel(std::integral_constant<bool, false>{},
|
||||
std::integral_constant<TailNumber, TailNumber::Even>{});
|
||||
}
|
||||
else if(!HasMainKBlockLoop && TailNum == TailNumber::Odd)
|
||||
{
|
||||
return launch_kernel(std::integral_constant<bool, false>{},
|
||||
std::integral_constant<TailNumber, TailNumber::Odd>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("Invalid HasMainKBlockLoop and TailNum combination for V3!\n");
|
||||
return 0.0f;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("Invalid pipeline version!\n");
|
||||
return 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b0,
|
||||
@@ -669,28 +300,39 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
|
||||
B1ElementwiseOperation b1_element_op,
|
||||
CElementwiseOperation c_element_op) override
|
||||
{
|
||||
return std::make_unique<RawArg>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const B0DataType*>(p_b0),
|
||||
static_cast<const B1DataType*>(p_b1),
|
||||
static_cast<CDataType*>(p_c),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
O,
|
||||
Batch,
|
||||
StrideA,
|
||||
StrideB0,
|
||||
StrideB1,
|
||||
StrideC,
|
||||
BatchStrideA,
|
||||
BatchStrideB0,
|
||||
BatchStrideB1,
|
||||
BatchStrideC,
|
||||
a_element_op,
|
||||
b0_element_op,
|
||||
acc_element_op,
|
||||
b1_element_op,
|
||||
c_element_op);
|
||||
|
||||
std::array<const void*, DeviceGemmGemmCommonBase::NumD0Tensor> p_d0_grid{};
|
||||
std::array<const void*, DeviceGemmGemmCommonBase::NumD1Tensor> p_d1_grid{};
|
||||
std::array<index_t, DeviceGemmGemmCommonBase::NumD0Tensor> StrideD0s{}, BatchStrideD0s{};
|
||||
std::array<index_t, DeviceGemmGemmCommonBase::NumD1Tensor> StrideD1s, BatchStrideD1s{};
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const B0DataType*>(p_b0),
|
||||
p_d0_grid,
|
||||
static_cast<const B1DataType*>(p_b1),
|
||||
p_d1_grid,
|
||||
static_cast<CDataType*>(p_c),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
O,
|
||||
Batch,
|
||||
StrideA,
|
||||
StrideB0,
|
||||
StrideD0s,
|
||||
StrideB1,
|
||||
StrideD1s,
|
||||
StrideC,
|
||||
BatchStrideA,
|
||||
BatchStrideB0,
|
||||
BatchStrideD0s,
|
||||
BatchStrideB1,
|
||||
BatchStrideD1s,
|
||||
BatchStrideC,
|
||||
a_element_op,
|
||||
b0_element_op,
|
||||
acc_element_op,
|
||||
b1_element_op,
|
||||
c_element_op);
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
@@ -0,0 +1,902 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <iostream>
|
||||
#include <cstdarg>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp"
|
||||
#include "ck/utility/scheduler_enum.hpp"
|
||||
#include "ck/utility/integral_constant.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename DeviceOp,
|
||||
typename GridwiseOp,
|
||||
bool HasMainKBlockLoop,
|
||||
TailNumber TailNum,
|
||||
bool IsMultiD>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_batched_gemm_gemm_wmma_cshuffle_v3(typename DeviceOp::Argument arg)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
||||
|
||||
__shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / arg.batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
const long_index_t a_batch_offset =
|
||||
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetABasePtr(g_idx)));
|
||||
const long_index_t b0_batch_offset =
|
||||
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB0BasePtr(g_idx)));
|
||||
const long_index_t b1_batch_offset =
|
||||
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
|
||||
const long_index_t c_e1_batch_offset =
|
||||
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetCE1BasePtr(g_idx)));
|
||||
|
||||
auto [p_d0s_grid, p_d1s_grid] = [&]() {
|
||||
if constexpr(IsMultiD)
|
||||
{
|
||||
auto create_grid = [](auto NumTensor, auto func, auto& arg_grid, auto&& grid_pointer) {
|
||||
static_for<0, decltype(NumTensor)::value, 1>{}([&](auto In) {
|
||||
const long_index_t batch_offset = __builtin_amdgcn_readfirstlane(func(In));
|
||||
grid_pointer(In) = arg_grid(In) + batch_offset;
|
||||
});
|
||||
return std::move(grid_pointer);
|
||||
};
|
||||
auto get_d0_base_ptr = [&arg, &g_idx](auto d_idx) {
|
||||
return arg.compute_base_ptr_of_batch.GetD0BasePtr(g_idx, d_idx);
|
||||
};
|
||||
auto get_d1_base_ptr = [&arg, &g_idx](auto d_idx) {
|
||||
return arg.compute_base_ptr_of_batch.GetD1BasePtr(g_idx, d_idx);
|
||||
};
|
||||
auto d0s_grid = create_grid(ck::integral_constant<ck::index_t, DeviceOp::NumD0Tensor>{},
|
||||
get_d0_base_ptr,
|
||||
arg.p_d0s_grid,
|
||||
GridwiseOp::MakeD0sGridPointer());
|
||||
auto d1s_grid = create_grid(ck::integral_constant<ck::index_t, DeviceOp::NumD1Tensor>{},
|
||||
get_d1_base_ptr,
|
||||
arg.p_d1s_grid,
|
||||
GridwiseOp::MakeD1sGridPointer());
|
||||
return std::make_pair(d0s_grid, d1s_grid);
|
||||
}
|
||||
else
|
||||
{
|
||||
return std::make_pair(Tuple<>{}, Tuple<>{});
|
||||
}
|
||||
}();
|
||||
|
||||
GridwiseOp::template Run<HasMainKBlockLoop, TailNum>(
|
||||
arg.p_a_grid + a_batch_offset,
|
||||
arg.p_b0_grid + b0_batch_offset,
|
||||
p_d0s_grid,
|
||||
arg.p_b1_grid + b1_batch_offset,
|
||||
p_d1s_grid,
|
||||
arg.p_c_e1_grid + c_e1_batch_offset,
|
||||
p_shared,
|
||||
arg.a_grid_desc,
|
||||
arg.b0_grid_desc,
|
||||
arg.d0s_grid_desc,
|
||||
arg.b1_grid_desc,
|
||||
arg.d1s_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
arg.c_e1_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
arg.a_element_op,
|
||||
arg.b0_element_op,
|
||||
arg.acc_element_op,
|
||||
arg.b1_element_op,
|
||||
arg.cde1_element_op,
|
||||
arg.block_2_etile_map);
|
||||
#else
|
||||
ignore = arg;
|
||||
#endif // (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)
|
||||
}
|
||||
|
||||
template <typename DeviceOp,
|
||||
GemmSpecialization GemmSpec,
|
||||
typename ALayout,
|
||||
typename B0layout,
|
||||
typename D0sLayout,
|
||||
typename B1Layout,
|
||||
typename D1sLayout,
|
||||
typename CE1Layout,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t LPerBlock, // Gemm0NPerBlock
|
||||
ck::index_t KPerBlock, // Gemm0KPerBlock
|
||||
ck::index_t NPerBlock, // Gemm1NPerBlock
|
||||
typename ADataType,
|
||||
typename B0DataType,
|
||||
typename B1DataType,
|
||||
typename AccDataType,
|
||||
typename CE1DataType,
|
||||
typename D0sDataType,
|
||||
typename D1sDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename B0ElementwiseOperation,
|
||||
typename AccElementwiseOperation,
|
||||
typename B1ElementwiseOperation,
|
||||
typename CDE1ElementwiseOperation,
|
||||
ck::index_t AK1,
|
||||
ck::index_t BK1,
|
||||
ck::index_t L1, // B1K1
|
||||
ck::index_t MPerWmma, // Gemm0/1 MPerWmma
|
||||
ck::index_t LPerWmma, // Gemm0/1 NPerWmma
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t B0BlockTransferSrcVectorDim,
|
||||
ck::index_t B0BlockTransferSrcScalarPerVector,
|
||||
ck::index_t B1BlockTransferSrcVectorDim,
|
||||
ck::index_t B1BlockTransferSrcScalarPerVector,
|
||||
ck::index_t CDE0BlockTransferSrcScalarPerVector,
|
||||
ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
bool IsMultiD = false>
|
||||
struct DeviceGemmGemm_Wmma_CShuffleV3_Common
|
||||
{
|
||||
static constexpr ck::index_t NumD0Tensor = []() {
|
||||
if constexpr(IsMultiD)
|
||||
{
|
||||
return DeviceOp::NumD0Tensor;
|
||||
}
|
||||
return 0;
|
||||
}();
|
||||
static constexpr ck::index_t NumD1Tensor = []() {
|
||||
if constexpr(IsMultiD)
|
||||
{
|
||||
return DeviceOp::NumD1Tensor;
|
||||
}
|
||||
return 0;
|
||||
}();
|
||||
|
||||
struct GridDescriptorCreator
|
||||
{
|
||||
// TODO: Now that we are no longer using NumDim or TensorSpec, we can probably use a simpler
|
||||
// Transform operator or just not use one at all.
|
||||
using Transform = TransformBatchedContractionContractionToBatchedGemmGemm_Wmma<
|
||||
Sequence<1, 1, 1, 1, 1>,
|
||||
Sequence<MPerBlock, LPerBlock, KPerBlock, NPerBlock>,
|
||||
GemmSpec,
|
||||
TensorSpecialization::Default, // ASpec
|
||||
TensorSpecialization::Default, // B0Spec
|
||||
TensorSpecialization::Default, // B1Spec
|
||||
TensorSpecialization::Default>; // CSpec
|
||||
|
||||
__host__ __device__ static auto
|
||||
MakeAGridDescriptor(const std::array<index_t, 3>& a_g_m_k_lengths_vec,
|
||||
const std::array<index_t, 3>& a_g_m_k_strides_vec)
|
||||
{
|
||||
return Transform::MakeAGridDescriptor_AK0_M_AK1(
|
||||
Transform::MakeAGridDescriptor_M_K(a_g_m_k_lengths_vec, a_g_m_k_strides_vec),
|
||||
Number<AK1>{});
|
||||
}
|
||||
|
||||
__host__ __device__ static auto
|
||||
MakeB0GridDescriptor(const std::array<index_t, 3>& b0_g_l_k_lengths_vec,
|
||||
const std::array<index_t, 3>& b0_g_l_k_strides_vec)
|
||||
{
|
||||
return Transform::MakeB0GridDescriptor_BK0_N_BK1(
|
||||
Transform::MakeB0GridDescriptor_N_K(b0_g_l_k_lengths_vec, b0_g_l_k_strides_vec),
|
||||
Number<BK1>{});
|
||||
}
|
||||
|
||||
__host__ __device__ static auto
|
||||
MakeB1GridDescriptor(const std::array<index_t, 3>& b1_g_n_l_lengths_vec,
|
||||
const std::array<index_t, 3>& b1_g_n_l_strides_vec)
|
||||
{
|
||||
return Transform::MakeB1GridDescriptor_BK0_N_BK1(
|
||||
Transform::MakeB1GridDescriptor_N_K(b1_g_n_l_lengths_vec, b1_g_n_l_strides_vec),
|
||||
Number<L1>{});
|
||||
}
|
||||
|
||||
__host__ __device__ static auto
|
||||
MakeD0GridDescriptor(const std::array<index_t, 3>& d0_g_m_n_lengths_vec,
|
||||
const std::array<index_t, 3>& d0_g_m_n_strides_vec)
|
||||
{
|
||||
return Transform::MakeCGridDescriptor_M_N(d0_g_m_n_lengths_vec, d0_g_m_n_strides_vec);
|
||||
}
|
||||
|
||||
__host__ __device__ static auto MakeD0sGridDescriptor(
|
||||
const std::array<std::array<index_t, 3>, NumD0Tensor>& d0_g_m_n_lengths_vec,
|
||||
const std::array<std::array<index_t, 3>, NumD0Tensor>& d0_g_m_n_strides_vec)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
return MakeD0GridDescriptor(d0_g_m_n_lengths_vec[i], d0_g_m_n_strides_vec[i]);
|
||||
},
|
||||
Number<NumD0Tensor>{});
|
||||
}
|
||||
|
||||
__host__ __device__ static auto MakeD1sGridDescriptor(
|
||||
const std::array<std::array<index_t, 3>, NumD1Tensor>& d1_g_m_o_lengths_vec,
|
||||
const std::array<std::array<index_t, 3>, NumD1Tensor>& d1_g_m_o_strides_vec)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
return MakeE1GridDescriptor(d1_g_m_o_lengths_vec[i], d1_g_m_o_strides_vec[i]);
|
||||
},
|
||||
Number<NumD1Tensor>{});
|
||||
}
|
||||
|
||||
__host__ __device__ static auto
|
||||
MakeE1GridDescriptor(const std::array<index_t, 3>& e1_g_m_n_lengths_vec,
|
||||
const std::array<index_t, 3>& e1_g_m_n_strides_vec)
|
||||
{
|
||||
return Transform::MakeCGridDescriptor_M_N(e1_g_m_n_lengths_vec, e1_g_m_n_strides_vec);
|
||||
}
|
||||
};
|
||||
|
||||
using AGridDesc = decltype(GridDescriptorCreator::MakeAGridDescriptor({}, {}));
|
||||
using B0GridDesc = decltype(GridDescriptorCreator::MakeB0GridDescriptor({}, {}));
|
||||
using D0sGridDesc =
|
||||
remove_cvref_t<decltype(GridDescriptorCreator::MakeD0sGridDescriptor({}, {}))>;
|
||||
using B1GridDesc = decltype(GridDescriptorCreator::MakeB1GridDescriptor({}, {}));
|
||||
using D1sGridDesc =
|
||||
remove_cvref_t<decltype(GridDescriptorCreator::MakeD1sGridDescriptor({}, {}))>;
|
||||
using E1GridDesc = decltype(GridDescriptorCreator::MakeE1GridDescriptor({}, {}));
|
||||
using CGridDesc_M_N =
|
||||
decltype(GridDescriptorCreator::Transform::MakeCGridDescriptor_M_N({}, {}));
|
||||
|
||||
struct ComputeBasePtrOfStridedBatch
|
||||
{
|
||||
ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
|
||||
index_t BatchStrideB0,
|
||||
index_t BatchStrideB1,
|
||||
index_t BatchStrideC)
|
||||
: BatchStrideA_(BatchStrideA),
|
||||
BatchStrideB0_(BatchStrideB0),
|
||||
BatchStrideB1_(BatchStrideB1),
|
||||
BatchStrideC_E1_(BatchStrideC)
|
||||
{
|
||||
}
|
||||
|
||||
ComputeBasePtrOfStridedBatch(index_t BatchStrideA0,
|
||||
index_t BatchStrideB0,
|
||||
std::array<index_t, NumD0Tensor> BatchStrideD0s,
|
||||
index_t BatchStrideB1,
|
||||
std::array<index_t, NumD1Tensor> BatchStrideD1s,
|
||||
index_t BatchStrideE1)
|
||||
: BatchStrideA_(BatchStrideA0),
|
||||
BatchStrideB0_(BatchStrideB0),
|
||||
BatchStrideD0s_(BatchStrideD0s),
|
||||
BatchStrideB1_(BatchStrideB1),
|
||||
BatchStrideD1s_(BatchStrideD1s),
|
||||
BatchStrideC_E1_(BatchStrideE1)
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideA_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideB0_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideB1_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetCE1BasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideC_E1_);
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx,
|
||||
Number<I> d0_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideD0s_[d0_idx]);
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr long_index_t GetD1BasePtr(index_t g_idx,
|
||||
Number<I> d1_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideD1s_[d1_idx]);
|
||||
}
|
||||
|
||||
private:
|
||||
index_t BatchStrideA_;
|
||||
index_t BatchStrideB0_;
|
||||
std::array<index_t, NumD0Tensor> BatchStrideD0s_;
|
||||
index_t BatchStrideB1_;
|
||||
std::array<index_t, NumD1Tensor> BatchStrideD1s_;
|
||||
index_t BatchStrideC_E1_;
|
||||
};
|
||||
};
|
||||
|
||||
template <typename DeviceOp,
|
||||
GemmSpecialization GemmSpec,
|
||||
typename ALayout,
|
||||
typename B0layout,
|
||||
typename D0sLayout,
|
||||
typename B1Layout,
|
||||
typename D1sLayout,
|
||||
typename CE1Layout,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t LPerBlock, // Gemm0NPerBlock
|
||||
ck::index_t KPerBlock, // Gemm0KPerBlock
|
||||
ck::index_t NPerBlock, // Gemm1NPerBlock
|
||||
typename ADataType,
|
||||
typename B0DataType,
|
||||
typename B1DataType,
|
||||
typename AccDataType,
|
||||
typename CE1DataType,
|
||||
typename D0sDataType,
|
||||
typename D1sDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename B0ElementwiseOperation,
|
||||
typename AccElementwiseOperation,
|
||||
typename B1ElementwiseOperation,
|
||||
typename CDE1ElementwiseOperation,
|
||||
ck::index_t AK1,
|
||||
ck::index_t BK1,
|
||||
ck::index_t L1, // B1K1
|
||||
ck::index_t MPerWmma, // Gemm0/1 MPerWmma
|
||||
ck::index_t LPerWmma, // Gemm0/1 NPerWmma
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t B0BlockTransferSrcVectorDim,
|
||||
ck::index_t B0BlockTransferSrcScalarPerVector,
|
||||
ck::index_t B1BlockTransferSrcVectorDim,
|
||||
ck::index_t B1BlockTransferSrcScalarPerVector,
|
||||
ck::index_t CDE0BlockTransferSrcScalarPerVector,
|
||||
ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
bool IsMultiD = false>
|
||||
struct DeviceGemmGemm_Wmma_CShuffleV3_Common_Invoker_Arg
|
||||
{
|
||||
using GridwiseGemm = typename DeviceOp::GridwiseOp;
|
||||
using Common =
|
||||
DeviceGemmGemm_Wmma_CShuffleV3_Common<DeviceOp,
|
||||
GemmSpec,
|
||||
ALayout,
|
||||
B0layout,
|
||||
D0sLayout,
|
||||
B1Layout,
|
||||
D1sLayout,
|
||||
CE1Layout,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
LPerBlock,
|
||||
KPerBlock,
|
||||
NPerBlock,
|
||||
ADataType,
|
||||
B0DataType,
|
||||
B1DataType,
|
||||
AccDataType,
|
||||
CE1DataType,
|
||||
D0sDataType,
|
||||
D1sDataType,
|
||||
AElementwiseOperation,
|
||||
B0ElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
B1ElementwiseOperation,
|
||||
CDE1ElementwiseOperation,
|
||||
AK1,
|
||||
BK1,
|
||||
L1,
|
||||
MPerWmma,
|
||||
LPerWmma,
|
||||
BlkGemmPipelineVer,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
B0BlockTransferSrcVectorDim,
|
||||
B0BlockTransferSrcScalarPerVector,
|
||||
B1BlockTransferSrcVectorDim,
|
||||
B1BlockTransferSrcScalarPerVector,
|
||||
CDE0BlockTransferSrcScalarPerVector,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
IsMultiD>;
|
||||
|
||||
static constexpr auto NumD0Tensor = Common::NumD0Tensor;
|
||||
static constexpr auto NumD1Tensor = Common::NumD1Tensor;
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
using arr3 = std::array<ck::index_t, 3>;
|
||||
|
||||
Argument(const ADataType* p_a_grid_,
|
||||
const B0DataType* p_b0_grid_,
|
||||
std::array<const void*, NumD0Tensor> p_d0s_grid_,
|
||||
const B1DataType* p_b1_grid_,
|
||||
std::array<const void*, NumD1Tensor> p_d1s_grid_,
|
||||
CE1DataType* p_e1_grid_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t O_,
|
||||
index_t Batch,
|
||||
index_t StrideA,
|
||||
index_t StrideB0,
|
||||
std::array<index_t, NumD0Tensor> StrideD0s,
|
||||
index_t StrideB1,
|
||||
std::array<index_t, NumD1Tensor> StrideD1s,
|
||||
index_t StrideE1,
|
||||
index_t BatchStrideA,
|
||||
index_t BatchStrideB0,
|
||||
std::array<index_t, NumD0Tensor> BatchStrideD0s,
|
||||
index_t BatchStrideB1,
|
||||
std::array<index_t, NumD1Tensor> BatchStrideD1s,
|
||||
index_t BatchStrideE1,
|
||||
AElementwiseOperation a_element_op_,
|
||||
B0ElementwiseOperation b0_element_op_,
|
||||
AccElementwiseOperation acc_element_op_,
|
||||
B1ElementwiseOperation b1_element_op_,
|
||||
CDE1ElementwiseOperation cde1_element_op_)
|
||||
: p_a_grid{p_a_grid_},
|
||||
p_b0_grid{p_b0_grid_},
|
||||
p_d0s_grid{},
|
||||
p_b1_grid{p_b1_grid_},
|
||||
p_d1s_grid{},
|
||||
p_c_e1_grid{p_e1_grid_},
|
||||
M{M_},
|
||||
N{N_},
|
||||
K{K_},
|
||||
O{O_},
|
||||
batch_count{Batch},
|
||||
a_element_op{a_element_op_},
|
||||
b0_element_op{b0_element_op_},
|
||||
acc_element_op{acc_element_op_},
|
||||
b1_element_op{b1_element_op_},
|
||||
cde1_element_op{cde1_element_op_},
|
||||
compute_base_ptr_of_batch{BatchStrideA,
|
||||
BatchStrideB0,
|
||||
BatchStrideD0s,
|
||||
BatchStrideB1,
|
||||
BatchStrideD1s,
|
||||
BatchStrideE1}
|
||||
{
|
||||
|
||||
a_g_m_k_lengths = arr3{batch_count, M, K};
|
||||
a_g_m_k_strides = arr3{BatchStrideA, StrideA, 1}; // A layout [batch_count, M, K]
|
||||
|
||||
b0_g_n_k_lengths = arr3{batch_count, N, K};
|
||||
b0_g_n_k_strides = arr3{BatchStrideB0, StrideB0, 1}; // B0 layout [batch_count, N, K]
|
||||
|
||||
b1_g_o_n_lengths = arr3{batch_count, O, N};
|
||||
b1_g_o_n_strides =
|
||||
is_same_v<B1Layout, tensor_layout::gemm::RowMajor>
|
||||
? arr3{BatchStrideB1, 1, StrideB1} // B1 layout [batch_count, N, O]
|
||||
: arr3{BatchStrideB1, StrideB1, 1}; // B1 layout [batch_count, O, N]
|
||||
|
||||
e1_g_m_o_lengths = arr3{batch_count, M, O};
|
||||
e1_g_m_o_strides = arr3{BatchStrideE1, StrideE1, 1}; // C layout [batch_count, M, O]
|
||||
|
||||
a_grid_desc = Common::GridDescriptorCreator::MakeAGridDescriptor(a_g_m_k_lengths,
|
||||
a_g_m_k_strides);
|
||||
b0_grid_desc = Common::GridDescriptorCreator::MakeB0GridDescriptor(b0_g_n_k_lengths,
|
||||
b0_g_n_k_strides);
|
||||
b1_grid_desc = Common::GridDescriptorCreator::MakeB1GridDescriptor(b1_g_o_n_lengths,
|
||||
b1_g_o_n_strides);
|
||||
c_e1_grid_desc_m_n = Common::GridDescriptorCreator::MakeE1GridDescriptor(
|
||||
e1_g_m_o_lengths, e1_g_m_o_strides);
|
||||
c_e1_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
GridwiseGemm::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
c_e1_grid_desc_m_n);
|
||||
block_2_etile_map = GridwiseGemm::MakeDefaultBlock2ETileMap(c_e1_grid_desc_m_n, 1, 1);
|
||||
|
||||
if constexpr(IsMultiD)
|
||||
{
|
||||
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
|
||||
using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
|
||||
|
||||
// D0s layout [batch_count, M, N]
|
||||
d0s_g_m_n_lengths[i] = arr3{batch_count, M, N};
|
||||
d0s_g_m_n_strides[i] = arr3{BatchStrideD0s[i], StrideD0s[i], 1};
|
||||
|
||||
// D0 pointer
|
||||
p_d0s_grid(i) = static_cast<const D0DataType*>(p_d0s_grid_[i]);
|
||||
});
|
||||
// D0 desc
|
||||
d0s_grid_desc = Common::GridDescriptorCreator::MakeD0sGridDescriptor(
|
||||
d0s_g_m_n_lengths, d0s_g_m_n_strides);
|
||||
|
||||
static_for<0, NumD1Tensor, 1>{}([&](auto i) {
|
||||
using D1DataType = remove_cvref_t<tuple_element_t<i.value, D1sDataType>>;
|
||||
|
||||
// D1s layout [batch_count, M, O]
|
||||
d1s_g_m_o_lengths[i] = arr3{batch_count, M, O};
|
||||
d1s_g_m_o_strides[i] = arr3{BatchStrideD1s[i], StrideD1s[i], 1};
|
||||
|
||||
// D1 pointer
|
||||
p_d1s_grid(i) = static_cast<const D1DataType*>(p_d1s_grid_[i]);
|
||||
});
|
||||
// D1 desc
|
||||
d1s_grid_desc = Common::GridDescriptorCreator::MakeD1sGridDescriptor(
|
||||
d1s_g_m_o_lengths, d1s_g_m_o_strides);
|
||||
|
||||
d1s_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
GridwiseGemm::MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
d1s_grid_desc);
|
||||
}
|
||||
}
|
||||
|
||||
// Pointers
|
||||
const ADataType* p_a_grid;
|
||||
const B0DataType* p_b0_grid;
|
||||
typename GridwiseGemm::D0sGridPointer p_d0s_grid;
|
||||
const B1DataType* p_b1_grid;
|
||||
typename GridwiseGemm::D1sGridPointer p_d1s_grid;
|
||||
CE1DataType* p_c_e1_grid;
|
||||
|
||||
// Raw Problem Size
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
index_t O;
|
||||
index_t batch_count;
|
||||
|
||||
arr3 a_g_m_k_lengths;
|
||||
arr3 a_g_m_k_strides;
|
||||
arr3 b0_g_n_k_lengths;
|
||||
arr3 b0_g_n_k_strides;
|
||||
std::array<arr3, NumD0Tensor> d0s_g_m_n_lengths;
|
||||
std::array<arr3, NumD0Tensor> d0s_g_m_n_strides;
|
||||
arr3 b1_g_o_n_lengths;
|
||||
arr3 b1_g_o_n_strides;
|
||||
std::array<arr3, NumD1Tensor> d1s_g_m_o_lengths;
|
||||
std::array<arr3, NumD1Tensor> d1s_g_m_o_strides;
|
||||
arr3 e1_g_m_o_lengths;
|
||||
arr3 e1_g_m_o_strides;
|
||||
|
||||
AElementwiseOperation a_element_op;
|
||||
B0ElementwiseOperation b0_element_op;
|
||||
AccElementwiseOperation acc_element_op;
|
||||
B1ElementwiseOperation b1_element_op;
|
||||
CDE1ElementwiseOperation cde1_element_op;
|
||||
|
||||
// Grid descriptors and other mem calculators
|
||||
typename Common::AGridDesc a_grid_desc;
|
||||
typename Common::B0GridDesc b0_grid_desc;
|
||||
std::conditional_t<IsMultiD, typename Common::D0sGridDesc, Tuple<>> d0s_grid_desc;
|
||||
typename Common::B1GridDesc b1_grid_desc;
|
||||
typename Common::D1sGridDesc d1s_grid_desc;
|
||||
std::conditional_t<
|
||||
IsMultiD,
|
||||
typename GridwiseGemm::D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
Tuple<>>
|
||||
d1s_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
|
||||
std::conditional_t<IsMultiD, typename Common::E1GridDesc, typename Common::CGridDesc_M_N>
|
||||
c_e1_grid_desc_m_n;
|
||||
typename GridwiseGemm::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_e1_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
|
||||
typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map;
|
||||
|
||||
typename Common::ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch;
|
||||
};
|
||||
|
||||
/// @brief Helper structure responsible for kernel invocation.
|
||||
///
|
||||
/// @paragraph The `Invoker` class is responsible for preparation and invocation of actual GPU
|
||||
/// kernel function. It usually determines the launched grid size prepares kernel
|
||||
/// arguments as well as perform specific kernel configuration selection based on
|
||||
/// runtime arguments.
|
||||
///
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
const auto M0 = math::integer_divide_ceil(arg.M, MPerBlock);
|
||||
const auto N0 = math::integer_divide_ceil(arg.O, NPerBlock);
|
||||
|
||||
const index_t grid_size = arg.batch_count * M0 * N0;
|
||||
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop, auto tail_number) {
|
||||
constexpr bool has_loop = decltype(has_main_k_block_loop)::value;
|
||||
constexpr TailNumber tail_num = decltype(tail_number)::value;
|
||||
const auto kernel = kernel_batched_gemm_gemm_wmma_cshuffle_v3<DeviceOp,
|
||||
GridwiseGemm,
|
||||
has_loop,
|
||||
tail_num,
|
||||
IsMultiD>;
|
||||
return launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, arg);
|
||||
};
|
||||
|
||||
bool HasMainKBlockLoop = GridwiseGemm::CalculateHasMainKBlockLoop(arg.K);
|
||||
TailNumber TailNum = GridwiseGemm::CalculateKBlockLoopTailNum(arg.K);
|
||||
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if(HasMainKBlockLoop && TailNum == TailNumber::Full)
|
||||
{
|
||||
return launch_kernel(std::integral_constant<bool, true>{},
|
||||
std::integral_constant<TailNumber, TailNumber::Full>{});
|
||||
}
|
||||
else if(!HasMainKBlockLoop && TailNum == TailNumber::Full)
|
||||
{
|
||||
return launch_kernel(std::integral_constant<bool, false>{},
|
||||
std::integral_constant<TailNumber, TailNumber::Full>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("Invalid HasMainKBlockLoop and TailNum combination for V1!\n");
|
||||
return 0.0f;
|
||||
}
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
if(HasMainKBlockLoop && TailNum == TailNumber::Full)
|
||||
{
|
||||
return launch_kernel(std::integral_constant<bool, true>{},
|
||||
std::integral_constant<TailNumber, TailNumber::Full>{});
|
||||
}
|
||||
else if(!HasMainKBlockLoop && TailNum == TailNumber::Even)
|
||||
{
|
||||
return launch_kernel(std::integral_constant<bool, false>{},
|
||||
std::integral_constant<TailNumber, TailNumber::Even>{});
|
||||
}
|
||||
else if(!HasMainKBlockLoop && TailNum == TailNumber::Odd)
|
||||
{
|
||||
return launch_kernel(std::integral_constant<bool, false>{},
|
||||
std::integral_constant<TailNumber, TailNumber::Odd>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("Invalid HasMainKBlockLoop and TailNum combination for V3!\n");
|
||||
return 0.0f;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("Invalid pipeline version!\n");
|
||||
return 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
// check if DsLayout is supported
|
||||
template <typename RefLayout, typename DsLayout, const index_t NumDTensor>
|
||||
static constexpr bool CheckDLayout()
|
||||
{
|
||||
bool valid = true;
|
||||
// iterate over DLayout tuple
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
// if RefLayout and DLayout are same, keep valid true, otherwise false
|
||||
valid = valid && is_same_v<RefLayout, DLayout>;
|
||||
});
|
||||
return valid;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
// Print lambda with env check and printf() style formmating.
|
||||
const char* curFunc = __func__;
|
||||
auto print = [&curFunc](const char* format, ...) -> void {
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
#if defined(__clang__)
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wformat-nonliteral"
|
||||
#endif
|
||||
va_list args;
|
||||
va_start(args, format);
|
||||
std::vfprintf(stdout, format, args);
|
||||
va_end(args);
|
||||
#if defined(__clang__)
|
||||
#pragma clang diagnostic pop
|
||||
#endif
|
||||
std::cout << "In file: " << __FILE__ << ", function: " << curFunc << "\n";
|
||||
}
|
||||
};
|
||||
|
||||
if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported()))
|
||||
{
|
||||
print("DeviceOp: Arch err\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<ADataType, f8_t> || std::is_same_v<ADataType, bf8_t> ||
|
||||
std::is_same_v<B0DataType, f8_t> || std::is_same_v<B0DataType, bf8_t> ||
|
||||
std::is_same_v<B1DataType, f8_t> || std::is_same_v<B1DataType, bf8_t>)
|
||||
{
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
print("DeviceOp: gfx 11 does not support fp8\n");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
|
||||
{
|
||||
print("DeviceOp: Acc0 Type err\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(!(is_same_v<ALayout, tensor_layout::gemm::RowMajor>))
|
||||
{
|
||||
print("DeviceOp: A layout must be Row\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(!(is_same_v<B1Layout, tensor_layout::gemm::RowMajor> ||
|
||||
is_same_v<B1Layout, tensor_layout::gemm::ColumnMajor>))
|
||||
{
|
||||
print("DeviceOp: B1 layout must be Column or Row\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(!(is_same_v<CE1Layout, tensor_layout::gemm::RowMajor>))
|
||||
{
|
||||
print("DeviceOp: C layout must be Row\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Other padding modes have not been tested and do not get checked individually.
|
||||
if constexpr(GemmSpec != GemmSpecialization::Default &&
|
||||
GemmSpec != GemmSpecialization::MNKOPadding)
|
||||
{
|
||||
print("Padding mode must be default or MNKO\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Per wmma dimensions not equal to 16 are very untested.
|
||||
if constexpr(MPerWmma != 16 || LPerWmma != 16 || DeviceOp::NPerWmma != 16)
|
||||
{
|
||||
print("M, L, N per Wmma must be 16\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(IsMultiD)
|
||||
{
|
||||
if constexpr(!(is_same_v<B0layout, tensor_layout::gemm::ColumnMajor>))
|
||||
{
|
||||
print("DeviceOp: B0 layout must be Column\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(!(CheckDLayout<tensor_layout::gemm::RowMajor, D0sLayout, NumD0Tensor>()))
|
||||
{
|
||||
print("DeviceOp: All D0s layout must be Row\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(!(CheckDLayout<tensor_layout::gemm::RowMajor, D1sLayout, NumD1Tensor>()))
|
||||
{
|
||||
print("DeviceOp: All D1s layout must be Row\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc,
|
||||
arg.b0_grid_desc,
|
||||
arg.d0s_grid_desc,
|
||||
arg.b1_grid_desc,
|
||||
arg.d1s_grid_desc,
|
||||
arg.c_e1_grid_desc_m_n,
|
||||
arg.block_2_etile_map))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check scalar per vector requirement
|
||||
const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? arg.K : arg.M;
|
||||
const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? arg.K : arg.N;
|
||||
const auto cde0_extent_lowest = arg.N; // D0 tensors forced to be row-major
|
||||
const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? arg.N : arg.O;
|
||||
const auto cde1_extent_lowest = arg.O;
|
||||
|
||||
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
|
||||
b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 &&
|
||||
cde0_extent_lowest % CDE0BlockTransferSrcScalarPerVector == 0 &&
|
||||
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
|
||||
cde1_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
{
|
||||
print("DeviceOp: Data Transfer Vector scalar err\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check vector load/store requirement
|
||||
const auto a_stride_lowest =
|
||||
ABlockTransferSrcVectorDim == 2 ? arg.a_g_m_k_strides[2] : arg.a_g_m_k_strides[1];
|
||||
const auto b0_stride_lowest = B0BlockTransferSrcVectorDim == 2
|
||||
? arg.b0_g_n_k_strides[2]
|
||||
: arg.b0_g_n_k_strides[1];
|
||||
const auto b1_stride_lowest = B1BlockTransferSrcVectorDim == 2
|
||||
? arg.b1_g_o_n_strides[2]
|
||||
: arg.b1_g_o_n_strides[1];
|
||||
const auto e1_stride_lowest = arg.e1_g_m_o_strides[2];
|
||||
|
||||
// NOTE: We don't check D0s/D1s stride, as they are already forced to be row-major
|
||||
// and the lowest dimension stride is hardcoded to 1
|
||||
if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
|
||||
e1_stride_lowest == 1))
|
||||
{
|
||||
print("DeviceOp: Data Vectorize transfer err\n");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc,
|
||||
arg.b0_grid_desc,
|
||||
Tuple<>{},
|
||||
arg.b1_grid_desc,
|
||||
Tuple<>{},
|
||||
arg.c_e1_grid_desc_m_n,
|
||||
arg.block_2_etile_map))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check scalar per vector requirement
|
||||
const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? arg.K : arg.M;
|
||||
const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? arg.K : arg.N;
|
||||
const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? arg.N : arg.O;
|
||||
const auto c_extent_lowest = arg.O;
|
||||
|
||||
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
|
||||
b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 &&
|
||||
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
|
||||
c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
{
|
||||
print("DeviceOp: Data Transfer Vector scalar err\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check vector load/store requirement
|
||||
const auto a_stride_lowest =
|
||||
ABlockTransferSrcVectorDim == 2 ? arg.a_g_m_k_strides[2] : arg.a_g_m_k_strides[1];
|
||||
const auto b0_stride_lowest = B0BlockTransferSrcVectorDim == 2
|
||||
? arg.b0_g_n_k_strides[2]
|
||||
: arg.b0_g_n_k_strides[1];
|
||||
const auto b1_stride_lowest = B1BlockTransferSrcVectorDim == 2
|
||||
? arg.b1_g_o_n_strides[2]
|
||||
: arg.b1_g_o_n_strides[1];
|
||||
const auto c_stride_lowest = arg.e1_g_m_o_strides[2];
|
||||
|
||||
if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
|
||||
c_stride_lowest == 1))
|
||||
{
|
||||
print("DeviceOp: Data Vectorize transfer err\n");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MNKOPadding))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -3,91 +3,20 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3_common.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename DeviceOp, typename GridwiseOp, bool HasMainKBlockLoop, TailNumber TailNum>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3(typename DeviceOp::RawArg arg)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
||||
|
||||
__shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / arg.batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
const long_index_t a_batch_offset =
|
||||
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetABasePtr(g_idx)));
|
||||
const long_index_t b0_batch_offset =
|
||||
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB0BasePtr(g_idx)));
|
||||
const long_index_t b1_batch_offset =
|
||||
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
|
||||
const long_index_t e1_batch_offset =
|
||||
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetE1BasePtr(g_idx)));
|
||||
|
||||
auto p_d0s_grid = GridwiseOp::MakeD0sGridPointer();
|
||||
auto p_d1s_grid = GridwiseOp::MakeD1sGridPointer();
|
||||
|
||||
static_for<0, DeviceOp::NumD0Tensor, 1>{}([&](auto In) {
|
||||
const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(arg.compute_base_ptr_of_batch.GetD0BasePtr(g_idx, In)));
|
||||
p_d0s_grid(In) = arg.p_d0s_grid(In) + d0_batch_offset;
|
||||
});
|
||||
|
||||
static_for<0, DeviceOp::NumD1Tensor, 1>{}([&](auto In) {
|
||||
const long_index_t d1_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(arg.compute_base_ptr_of_batch.GetD1BasePtr(g_idx, In)));
|
||||
p_d1s_grid(In) = arg.p_d1s_grid(In) + d1_batch_offset;
|
||||
});
|
||||
|
||||
GridwiseOp::template Run<HasMainKBlockLoop, TailNum>(
|
||||
arg.p_a_grid + a_batch_offset,
|
||||
arg.p_b0_grid + b0_batch_offset,
|
||||
p_d0s_grid,
|
||||
arg.p_b1_grid + b1_batch_offset,
|
||||
p_d1s_grid,
|
||||
arg.p_e1_grid + e1_batch_offset,
|
||||
p_shared,
|
||||
arg.a_grid_desc,
|
||||
arg.b0_grid_desc,
|
||||
arg.d0s_grid_desc,
|
||||
arg.b1_grid_desc,
|
||||
arg.d1s_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
arg.e1_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
arg.a_element_op,
|
||||
arg.b0_element_op,
|
||||
arg.acc_element_op,
|
||||
arg.b1_element_op,
|
||||
arg.cde1_element_op,
|
||||
arg.block_2_etile_map);
|
||||
#else
|
||||
ignore = arg;
|
||||
#endif // (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)
|
||||
}
|
||||
|
||||
// Computes:
|
||||
// Acc = Acc_Op(A_Op(A) * B0_Op(B0), D0_0, D0_1, ...)
|
||||
// E = CDE1_Op(Acc_Op(Acc0) * B1_Op(B1), D1_0, D1_1, ...)
|
||||
@@ -184,151 +113,51 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3
|
||||
static constexpr index_t NumD0Tensor = D0sDataType::Size();
|
||||
static constexpr index_t NumD1Tensor = D1sDataType::Size();
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
|
||||
// To match XDL implementation NPerWmma (A.k.a Gemm1 NPerWmma) is set equal
|
||||
// to LPerWmma (A.k.a Gemm0 NPerWmma).
|
||||
static constexpr index_t NPerWmma = LPerWmma;
|
||||
|
||||
// TODO: Now that we are no longer using NumDim or TensorSpec, we can probably use a simpler
|
||||
// Transform operator or just not use one at all.
|
||||
using Transform = TransformBatchedContractionContractionToBatchedGemmGemm_Wmma<
|
||||
Sequence<1, 1, 1, 1, 1>,
|
||||
Sequence<MPerBlock, LPerBlock, KPerBlock, NPerBlock>,
|
||||
GemmSpec,
|
||||
TensorSpecialization::Default, // ASpec
|
||||
TensorSpecialization::Default, // B0Spec
|
||||
TensorSpecialization::Default, // B1Spec
|
||||
TensorSpecialization::Default>; // CSpec
|
||||
|
||||
__host__ __device__ static auto
|
||||
MakeAGridDescriptor(const std::array<index_t, 3>& a_g_m_k_lengths_vec,
|
||||
const std::array<index_t, 3>& a_g_m_k_strides_vec)
|
||||
{
|
||||
return Transform::MakeAGridDescriptor_AK0_M_AK1(
|
||||
Transform::MakeAGridDescriptor_M_K(a_g_m_k_lengths_vec, a_g_m_k_strides_vec),
|
||||
Number<AK1>{});
|
||||
}
|
||||
|
||||
__host__ __device__ static auto
|
||||
MakeB0GridDescriptor(const std::array<index_t, 3>& b0_g_l_k_lengths_vec,
|
||||
const std::array<index_t, 3>& b0_g_l_k_strides_vec)
|
||||
{
|
||||
return Transform::MakeB0GridDescriptor_BK0_N_BK1(
|
||||
Transform::MakeB0GridDescriptor_N_K(b0_g_l_k_lengths_vec, b0_g_l_k_strides_vec),
|
||||
Number<BK1>{});
|
||||
}
|
||||
|
||||
__host__ __device__ static auto
|
||||
MakeB1GridDescriptor(const std::array<index_t, 3>& b1_g_n_l_lengths_vec,
|
||||
const std::array<index_t, 3>& b1_g_n_l_strides_vec)
|
||||
{
|
||||
return Transform::MakeB1GridDescriptor_BK0_N_BK1(
|
||||
Transform::MakeB1GridDescriptor_N_K(b1_g_n_l_lengths_vec, b1_g_n_l_strides_vec),
|
||||
Number<L1>{});
|
||||
}
|
||||
|
||||
__host__ __device__ static auto
|
||||
MakeD0GridDescriptor(const std::array<index_t, 3>& d0_g_m_n_lengths_vec,
|
||||
const std::array<index_t, 3>& d0_g_m_n_strides_vec)
|
||||
{
|
||||
return Transform::MakeCGridDescriptor_M_N(d0_g_m_n_lengths_vec, d0_g_m_n_strides_vec);
|
||||
}
|
||||
|
||||
__host__ __device__ static auto MakeD0sGridDescriptor(
|
||||
const std::array<std::array<index_t, 3>, NumD0Tensor>& d0_g_m_n_lengths_vec,
|
||||
const std::array<std::array<index_t, 3>, NumD0Tensor>& d0_g_m_n_strides_vec)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
return MakeD0GridDescriptor(d0_g_m_n_lengths_vec[i], d0_g_m_n_strides_vec[i]);
|
||||
},
|
||||
Number<NumD0Tensor>{});
|
||||
}
|
||||
|
||||
__host__ __device__ static auto MakeD1sGridDescriptor(
|
||||
const std::array<std::array<index_t, 3>, NumD0Tensor>& d1_g_m_o_lengths_vec,
|
||||
const std::array<std::array<index_t, 3>, NumD0Tensor>& d1_g_m_o_strides_vec)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
return MakeE1GridDescriptor(d1_g_m_o_lengths_vec[i], d1_g_m_o_strides_vec[i]);
|
||||
},
|
||||
Number<NumD1Tensor>{});
|
||||
}
|
||||
|
||||
__host__ __device__ static auto
|
||||
MakeE1GridDescriptor(const std::array<index_t, 3>& e1_g_m_n_lengths_vec,
|
||||
const std::array<index_t, 3>& e1_g_m_n_strides_vec)
|
||||
{
|
||||
return Transform::MakeCGridDescriptor_M_N(e1_g_m_n_lengths_vec, e1_g_m_n_strides_vec);
|
||||
}
|
||||
|
||||
using AGridDesc = decltype(MakeAGridDescriptor({}, {}));
|
||||
using B0GridDesc = decltype(MakeB0GridDescriptor({}, {}));
|
||||
using D0sGridDesc = remove_cvref_t<decltype(MakeD0sGridDescriptor({}, {}))>;
|
||||
using B1GridDesc = decltype(MakeB1GridDescriptor({}, {}));
|
||||
using D1sGridDesc = remove_cvref_t<decltype(MakeD1sGridDescriptor({}, {}))>;
|
||||
using E1GridDesc = decltype(MakeE1GridDescriptor({}, {}));
|
||||
|
||||
struct ComputeBasePtrOfStridedBatch
|
||||
{
|
||||
ComputeBasePtrOfStridedBatch(index_t BatchStrideA0,
|
||||
index_t BatchStrideB0,
|
||||
std::array<index_t, NumD0Tensor> BatchStrideD0s,
|
||||
index_t BatchStrideB1,
|
||||
std::array<index_t, NumD1Tensor> BatchStrideD1s,
|
||||
index_t BatchStrideE1)
|
||||
: BatchStrideA0_(BatchStrideA0),
|
||||
BatchStrideB0_(BatchStrideB0),
|
||||
BatchStrideD0s_(BatchStrideD0s),
|
||||
BatchStrideB1_(BatchStrideB1),
|
||||
BatchStrideD1s_(BatchStrideD1s),
|
||||
BatchStrideE1_(BatchStrideE1)
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideA0_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideB0_);
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx,
|
||||
Number<I> d1_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideD0s_[d1_idx]);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideB1_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetE1BasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideE1_);
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto GetD1BasePtr(index_t g_idx, Number<I> d1_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideD1s_[d1_idx]);
|
||||
}
|
||||
|
||||
private:
|
||||
index_t BatchStrideA0_;
|
||||
index_t BatchStrideB0_;
|
||||
std::array<index_t, NumD0Tensor> BatchStrideD0s_;
|
||||
index_t BatchStrideB1_;
|
||||
std::array<index_t, NumD1Tensor> BatchStrideD1s_;
|
||||
index_t BatchStrideE1_;
|
||||
};
|
||||
using DeviceGemmGemmCommonBase =
|
||||
DeviceGemmGemm_Wmma_CShuffleV3_Common<DeviceOp,
|
||||
GemmSpec,
|
||||
ALayout,
|
||||
B0layout,
|
||||
D0sLayout,
|
||||
B1Layout,
|
||||
D1sLayout,
|
||||
E1Layout,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
LPerBlock,
|
||||
KPerBlock,
|
||||
NPerBlock,
|
||||
ADataType,
|
||||
B0DataType,
|
||||
B1DataType,
|
||||
AccDataType,
|
||||
E1DataType,
|
||||
D0sDataType,
|
||||
D1sDataType,
|
||||
AElementwiseOperation,
|
||||
B0ElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
B1ElementwiseOperation,
|
||||
CDE1ElementwiseOperation,
|
||||
AK1,
|
||||
BK1,
|
||||
L1,
|
||||
MPerWmma,
|
||||
LPerWmma,
|
||||
BlkGemmPipelineVer,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
B0BlockTransferSrcVectorDim,
|
||||
B0BlockTransferSrcScalarPerVector,
|
||||
B1BlockTransferSrcVectorDim,
|
||||
B1BlockTransferSrcScalarPerVector,
|
||||
CDE0BlockTransferSrcScalarPerVector,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
true>; // IsMultiD
|
||||
|
||||
// GridwiseOp
|
||||
using GridwiseOp = GridwiseBatchedGemmGemm_wmma_cshuffle_v3<
|
||||
@@ -350,12 +179,12 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3
|
||||
CDE1ElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
// InMemory Data Descriptor
|
||||
AGridDesc,
|
||||
B0GridDesc,
|
||||
D0sGridDesc,
|
||||
B1GridDesc,
|
||||
D1sGridDesc,
|
||||
E1GridDesc,
|
||||
typename DeviceGemmGemmCommonBase::AGridDesc,
|
||||
typename DeviceGemmGemmCommonBase::B0GridDesc,
|
||||
typename DeviceGemmGemmCommonBase::D0sGridDesc,
|
||||
typename DeviceGemmGemmCommonBase::B1GridDesc,
|
||||
typename DeviceGemmGemmCommonBase::D1sGridDesc,
|
||||
typename DeviceGemmGemmCommonBase::E1GridDesc,
|
||||
// Tiling Family
|
||||
MPerBlock,
|
||||
LPerBlock,
|
||||
@@ -402,430 +231,67 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
Transform::matrix_padder.PadN,
|
||||
DeviceGemmGemmCommonBase::GridDescriptorCreator::Transform::matrix_padder.PadN,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer>;
|
||||
|
||||
struct RawArg : public BaseArgument
|
||||
using DeviceGemmGemmCommon = DeviceGemmGemm_Wmma_CShuffleV3_Common_Invoker_Arg<
|
||||
DeviceOp,
|
||||
GemmSpec,
|
||||
ALayout,
|
||||
B0layout,
|
||||
D0sLayout,
|
||||
B1Layout,
|
||||
D1sLayout,
|
||||
E1Layout,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
LPerBlock,
|
||||
KPerBlock,
|
||||
NPerBlock,
|
||||
ADataType,
|
||||
B0DataType,
|
||||
B1DataType,
|
||||
AccDataType,
|
||||
E1DataType,
|
||||
D0sDataType,
|
||||
D1sDataType,
|
||||
AElementwiseOperation,
|
||||
B0ElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
B1ElementwiseOperation,
|
||||
CDE1ElementwiseOperation,
|
||||
AK1,
|
||||
BK1,
|
||||
L1,
|
||||
MPerWmma,
|
||||
LPerWmma,
|
||||
BlkGemmPipelineVer,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
B0BlockTransferSrcVectorDim,
|
||||
B0BlockTransferSrcScalarPerVector,
|
||||
B1BlockTransferSrcVectorDim,
|
||||
B1BlockTransferSrcScalarPerVector,
|
||||
CDE0BlockTransferSrcScalarPerVector,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
true>; // IsMultiD
|
||||
// Invoker
|
||||
using Invoker = typename DeviceGemmGemmCommon::Invoker;
|
||||
|
||||
// Argument
|
||||
using Argument = typename DeviceGemmGemmCommon::Argument;
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
using arr3 = std::array<ck::index_t, 3>;
|
||||
|
||||
RawArg(const ADataType* p_a_grid_,
|
||||
const B0DataType* p_b0_grid_,
|
||||
std::array<const void*, NumD0Tensor> p_d0s_grid_,
|
||||
const B1DataType* p_b1_grid_,
|
||||
std::array<const void*, NumD1Tensor> p_d1s_grid_,
|
||||
E1DataType* p_e1_grid_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t O_,
|
||||
index_t Batch,
|
||||
index_t StrideA,
|
||||
index_t StrideB0,
|
||||
std::array<index_t, NumD0Tensor> StrideD0s,
|
||||
index_t StrideB1,
|
||||
std::array<index_t, NumD1Tensor> StrideD1s,
|
||||
index_t StrideE1,
|
||||
index_t BatchStrideA,
|
||||
index_t BatchStrideB0,
|
||||
std::array<index_t, NumD0Tensor> BatchStrideD0s,
|
||||
index_t BatchStrideB1,
|
||||
std::array<index_t, NumD1Tensor> BatchStrideD1s,
|
||||
index_t BatchStrideE1,
|
||||
AElementwiseOperation a_element_op_,
|
||||
B0ElementwiseOperation b0_element_op_,
|
||||
AccElementwiseOperation acc_element_op_,
|
||||
B1ElementwiseOperation b1_element_op_,
|
||||
CDE1ElementwiseOperation cde1_element_op_)
|
||||
: p_a_grid{p_a_grid_},
|
||||
p_b0_grid{p_b0_grid_},
|
||||
p_d0s_grid{},
|
||||
p_b1_grid{p_b1_grid_},
|
||||
p_d1s_grid{},
|
||||
p_e1_grid{p_e1_grid_},
|
||||
M{M_},
|
||||
N{N_},
|
||||
K{K_},
|
||||
O{O_},
|
||||
batch_count{Batch},
|
||||
a_element_op{a_element_op_},
|
||||
b0_element_op{b0_element_op_},
|
||||
acc_element_op{acc_element_op_},
|
||||
b1_element_op{b1_element_op_},
|
||||
cde1_element_op{cde1_element_op_},
|
||||
compute_base_ptr_of_batch{BatchStrideA,
|
||||
BatchStrideB0,
|
||||
BatchStrideD0s,
|
||||
BatchStrideB1,
|
||||
BatchStrideD1s,
|
||||
BatchStrideE1}
|
||||
{
|
||||
|
||||
a_g_m_k_lengths = arr3{batch_count, M, K};
|
||||
a_g_m_k_strides = arr3{BatchStrideA, StrideA, 1}; // A layout [batch_count, M, K]
|
||||
|
||||
b0_g_n_k_lengths = arr3{batch_count, N, K};
|
||||
b0_g_n_k_strides = arr3{BatchStrideB0, StrideB0, 1}; // B0 layout [batch_count, N, K]
|
||||
|
||||
b1_g_o_n_lengths = arr3{batch_count, O, N};
|
||||
b1_g_o_n_strides =
|
||||
is_same_v<B1Layout, tensor_layout::gemm::RowMajor>
|
||||
? arr3{BatchStrideB1, 1, StrideB1} // B1 layout [batch_count, N, O]
|
||||
: arr3{BatchStrideB1, StrideB1, 1}; // B1 layout [batch_count, O, N]
|
||||
|
||||
e1_g_m_o_lengths = arr3{batch_count, M, O};
|
||||
e1_g_m_o_strides = arr3{BatchStrideE1, StrideE1, 1}; // C layout [batch_count, M, O]
|
||||
|
||||
a_grid_desc = MakeAGridDescriptor(a_g_m_k_lengths, a_g_m_k_strides);
|
||||
b0_grid_desc = MakeB0GridDescriptor(b0_g_n_k_lengths, b0_g_n_k_strides);
|
||||
b1_grid_desc = MakeB1GridDescriptor(b1_g_o_n_lengths, b1_g_o_n_strides);
|
||||
e1_grid_desc_m_n = MakeE1GridDescriptor(e1_g_m_o_lengths, e1_g_m_o_strides);
|
||||
e1_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
GridwiseOp::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
e1_grid_desc_m_n);
|
||||
block_2_etile_map = GridwiseOp::MakeDefaultBlock2ETileMap(e1_grid_desc_m_n, 1, 1);
|
||||
|
||||
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
|
||||
using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
|
||||
|
||||
// D0s layout [batch_count, M, N]
|
||||
d0s_g_m_n_lengths[i] = arr3{batch_count, M, N};
|
||||
d0s_g_m_n_strides[i] = arr3{BatchStrideD0s[i], StrideD0s[i], 1};
|
||||
|
||||
// D0 pointer
|
||||
p_d0s_grid(i) = static_cast<const D0DataType*>(p_d0s_grid_[i]);
|
||||
|
||||
// D0 desc
|
||||
d0s_grid_desc(i) = MakeD0GridDescriptor(d0s_g_m_n_lengths[i], d0s_g_m_n_strides[i]);
|
||||
});
|
||||
|
||||
static_for<0, NumD1Tensor, 1>{}([&](auto i) {
|
||||
using D1DataType = remove_cvref_t<tuple_element_t<i.value, D1sDataType>>;
|
||||
|
||||
// D1s layout [batch_count, M, O]
|
||||
d1s_g_m_o_lengths[i] = arr3{batch_count, M, O};
|
||||
d1s_g_m_o_strides[i] = arr3{BatchStrideD1s[i], StrideD1s[i], 1};
|
||||
|
||||
// D1 pointer
|
||||
p_d1s_grid(i) = static_cast<const D1DataType*>(p_d1s_grid_[i]);
|
||||
|
||||
// D1 desc
|
||||
d1s_grid_desc(i) = MakeE1GridDescriptor(d1s_g_m_o_lengths[i], d1s_g_m_o_strides[i]);
|
||||
});
|
||||
|
||||
d1s_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
GridwiseOp::MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(d1s_grid_desc);
|
||||
}
|
||||
|
||||
// Pointers
|
||||
const ADataType* p_a_grid;
|
||||
const B0DataType* p_b0_grid;
|
||||
typename GridwiseOp::D0sGridPointer p_d0s_grid;
|
||||
const B1DataType* p_b1_grid;
|
||||
typename GridwiseOp::D1sGridPointer p_d1s_grid;
|
||||
E1DataType* p_e1_grid;
|
||||
|
||||
// Raw Problem Size
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
index_t O;
|
||||
index_t batch_count;
|
||||
|
||||
arr3 a_g_m_k_lengths;
|
||||
arr3 a_g_m_k_strides;
|
||||
arr3 b0_g_n_k_lengths;
|
||||
arr3 b0_g_n_k_strides;
|
||||
std::array<arr3, NumD0Tensor> d0s_g_m_n_lengths;
|
||||
std::array<arr3, NumD0Tensor> d0s_g_m_n_strides;
|
||||
arr3 b1_g_o_n_lengths;
|
||||
arr3 b1_g_o_n_strides;
|
||||
std::array<arr3, NumD1Tensor> d1s_g_m_o_lengths;
|
||||
std::array<arr3, NumD1Tensor> d1s_g_m_o_strides;
|
||||
arr3 e1_g_m_o_lengths;
|
||||
arr3 e1_g_m_o_strides;
|
||||
|
||||
AElementwiseOperation a_element_op;
|
||||
B0ElementwiseOperation b0_element_op;
|
||||
AccElementwiseOperation acc_element_op;
|
||||
B1ElementwiseOperation b1_element_op;
|
||||
CDE1ElementwiseOperation cde1_element_op;
|
||||
|
||||
// Grid descriptors and other mem calculators
|
||||
AGridDesc a_grid_desc;
|
||||
B0GridDesc b0_grid_desc;
|
||||
D0sGridDesc d0s_grid_desc;
|
||||
B1GridDesc b1_grid_desc;
|
||||
D1sGridDesc d1s_grid_desc;
|
||||
typename GridwiseOp::D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
d1s_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
|
||||
E1GridDesc e1_grid_desc_m_n;
|
||||
typename GridwiseOp::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e1_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
|
||||
typename GridwiseOp::DefaultBlock2ETileMap block_2_etile_map;
|
||||
|
||||
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch;
|
||||
};
|
||||
|
||||
// check if DsLayout is supported
|
||||
template <typename RefLayout, typename DsLayout, const index_t NumDTensor>
|
||||
static constexpr bool CheckDLayout()
|
||||
{
|
||||
bool valid = true;
|
||||
// iterate over DLayout tuple
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
// if RefLayout and DLayout are same, keep valid true, otherwise false
|
||||
valid = valid && is_same_v<RefLayout, DLayout>;
|
||||
});
|
||||
return valid;
|
||||
return DeviceGemmGemmCommon::IsSupportedArgument(arg);
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument([[maybe_unused]] const RawArg& arg)
|
||||
{
|
||||
// Print lambda with env check and printf() style formmating.
|
||||
const char* curFunc = __func__;
|
||||
auto print = [&curFunc](const char* format, ...) -> void {
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
#if defined(__clang__)
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wformat-nonliteral"
|
||||
#endif
|
||||
va_list args;
|
||||
va_start(args, format);
|
||||
std::vfprintf(stdout, format, args);
|
||||
va_end(args);
|
||||
#if defined(__clang__)
|
||||
#pragma clang diagnostic pop
|
||||
#endif
|
||||
std::cout << "In file: " << __FILE__ << ", function: " << curFunc << "\n";
|
||||
}
|
||||
};
|
||||
|
||||
if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported()))
|
||||
{
|
||||
print("DeviceOp: Arch err\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<ADataType, f8_t> || std::is_same_v<ADataType, bf8_t> ||
|
||||
std::is_same_v<B0DataType, f8_t> || std::is_same_v<B0DataType, bf8_t> ||
|
||||
std::is_same_v<B1DataType, f8_t> || std::is_same_v<B1DataType, bf8_t>)
|
||||
{
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
print("DeviceOp: gfx 11 does not support fp8\n");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
|
||||
{
|
||||
print("DeviceOp: Acc0 Type err\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(!(is_same_v<ALayout, tensor_layout::gemm::RowMajor>))
|
||||
{
|
||||
print("DeviceOp: A layout must be Row\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(!(is_same_v<B0layout, tensor_layout::gemm::ColumnMajor>))
|
||||
{
|
||||
print("DeviceOp: B0 layout must be Column\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(!(CheckDLayout<tensor_layout::gemm::RowMajor, D0sLayout, NumD0Tensor>()))
|
||||
{
|
||||
print("DeviceOp: All D0s layout must be Row\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(!(is_same_v<B1Layout, tensor_layout::gemm::RowMajor> ||
|
||||
is_same_v<B1Layout, tensor_layout::gemm::ColumnMajor>))
|
||||
{
|
||||
print("DeviceOp: B1 layout must be Column or Row\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(!(CheckDLayout<tensor_layout::gemm::RowMajor, D1sLayout, NumD1Tensor>()))
|
||||
{
|
||||
print("DeviceOp: All D1s layout must be Row\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(!(is_same_v<E1Layout, tensor_layout::gemm::RowMajor>))
|
||||
{
|
||||
print("DeviceOp: C layout must be Row\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Other padding modes have not been tested and do not get checked individually.
|
||||
if constexpr(GemmSpec != GemmSpecialization::Default &&
|
||||
GemmSpec != GemmSpecialization::MNKOPadding)
|
||||
{
|
||||
print("Padding mode must be default or MNKO\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Per wmma dimensions not equal to 16 are very untested.
|
||||
if constexpr(MPerWmma != 16 || LPerWmma != 16 || NPerWmma != 16)
|
||||
{
|
||||
print("M, L, N per Wmma must be 16\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!GridwiseOp::CheckValidity(arg.a_grid_desc,
|
||||
arg.b0_grid_desc,
|
||||
arg.d0s_grid_desc,
|
||||
arg.b1_grid_desc,
|
||||
arg.d1s_grid_desc,
|
||||
arg.e1_grid_desc_m_n,
|
||||
arg.block_2_etile_map))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check scalar per vector requirement
|
||||
const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? arg.K : arg.M;
|
||||
const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? arg.K : arg.N;
|
||||
const auto cde0_extent_lowest = arg.N; // D0 tensors forced to be row-major
|
||||
const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? arg.N : arg.O;
|
||||
const auto cde1_extent_lowest = arg.O;
|
||||
|
||||
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
|
||||
b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 &&
|
||||
cde0_extent_lowest % CDE0BlockTransferSrcScalarPerVector == 0 &&
|
||||
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
|
||||
cde1_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
{
|
||||
print("DeviceOp: Data Transfer Vector scalar err\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check vector load/store requirement
|
||||
const auto a_stride_lowest =
|
||||
ABlockTransferSrcVectorDim == 2 ? arg.a_g_m_k_strides[2] : arg.a_g_m_k_strides[1];
|
||||
const auto b0_stride_lowest =
|
||||
B0BlockTransferSrcVectorDim == 2 ? arg.b0_g_n_k_strides[2] : arg.b0_g_n_k_strides[1];
|
||||
const auto b1_stride_lowest =
|
||||
B1BlockTransferSrcVectorDim == 2 ? arg.b1_g_o_n_strides[2] : arg.b1_g_o_n_strides[1];
|
||||
const auto e1_stride_lowest = arg.e1_g_m_o_strides[2];
|
||||
|
||||
// NOTE: We don't check D0s/D1s stride, as they are already forced to be row-major
|
||||
// and the lowest dimension stride is hardcoded to 1
|
||||
|
||||
if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
|
||||
e1_stride_lowest == 1))
|
||||
{
|
||||
print("DeviceOp: Data Vectorize transfer err\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MNKOPadding))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const RawArg*>(p_arg));
|
||||
return DeviceGemmGemmCommon::IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceOp::RawArg;
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
const auto M0 = math::integer_divide_ceil(arg.M, MPerBlock);
|
||||
const auto N0 = math::integer_divide_ceil(arg.O, NPerBlock);
|
||||
|
||||
const index_t grid_size = arg.batch_count * M0 * N0;
|
||||
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop, auto tail_number) {
|
||||
constexpr bool has_loop = decltype(has_main_k_block_loop)::value;
|
||||
constexpr TailNumber tn = tail_number;
|
||||
|
||||
const auto kernel =
|
||||
kernel_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3<DeviceOp,
|
||||
GridwiseOp,
|
||||
has_loop,
|
||||
tn>;
|
||||
|
||||
return launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, arg);
|
||||
};
|
||||
|
||||
bool HasMainKBlockLoop = GridwiseOp::CalculateHasMainKBlockLoop(arg.K);
|
||||
TailNumber TailNum = GridwiseOp::CalculateKBlockLoopTailNum(arg.K);
|
||||
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if(HasMainKBlockLoop && TailNum == TailNumber::Full)
|
||||
{
|
||||
return launch_kernel(std::integral_constant<bool, true>{},
|
||||
std::integral_constant<TailNumber, TailNumber::Full>{});
|
||||
}
|
||||
else if(!HasMainKBlockLoop && TailNum == TailNumber::Full)
|
||||
{
|
||||
return launch_kernel(std::integral_constant<bool, false>{},
|
||||
std::integral_constant<TailNumber, TailNumber::Full>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("Invalid HasMainKBlockLoop and TailNum combination for V1!\n");
|
||||
return 0.0f;
|
||||
}
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
if(HasMainKBlockLoop && TailNum == TailNumber::Full)
|
||||
{
|
||||
return launch_kernel(std::integral_constant<bool, true>{},
|
||||
std::integral_constant<TailNumber, TailNumber::Full>{});
|
||||
}
|
||||
else if(!HasMainKBlockLoop && TailNum == TailNumber::Even)
|
||||
{
|
||||
return launch_kernel(std::integral_constant<bool, false>{},
|
||||
std::integral_constant<TailNumber, TailNumber::Even>{});
|
||||
}
|
||||
else if(!HasMainKBlockLoop && TailNum == TailNumber::Odd)
|
||||
{
|
||||
return launch_kernel(std::integral_constant<bool, false>{},
|
||||
std::integral_constant<TailNumber, TailNumber::Odd>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("Invalid HasMainKBlockLoop and TailNum combination for V3!\n");
|
||||
return 0.0f;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("Invalid pipeline version!\n");
|
||||
return 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static auto MakeArgument(const ADataType* p_a0,
|
||||
const B0DataType* p_b0,
|
||||
std::array<const void*, NumD0Tensor> p_d0s,
|
||||
@@ -855,20 +321,20 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3
|
||||
B1ElementwiseOperation b1_element_op,
|
||||
CDE1ElementwiseOperation cde1_element_op)
|
||||
{
|
||||
return RawArg{p_a0, p_b0,
|
||||
p_d0s, p_b1,
|
||||
p_d1s, p_e1,
|
||||
MRaw, NRaw,
|
||||
KRaw, Gemm1NRaw,
|
||||
Batch, StrideA0,
|
||||
StrideB0, StrideD0s,
|
||||
StrideB1, StrideD1s,
|
||||
StrideE1, BatchStrideA0,
|
||||
BatchStrideB0, BatchStrideD0s,
|
||||
BatchStrideB1, BatchStrideD1s,
|
||||
BatchStrideE1, a0_element_op,
|
||||
b0_element_op, cde0_element_op,
|
||||
b1_element_op, cde1_element_op};
|
||||
return Argument{p_a0, p_b0,
|
||||
p_d0s, p_b1,
|
||||
p_d1s, p_e1,
|
||||
MRaw, NRaw,
|
||||
KRaw, Gemm1NRaw,
|
||||
Batch, StrideA0,
|
||||
StrideB0, StrideD0s,
|
||||
StrideB1, StrideD1s,
|
||||
StrideE1, BatchStrideA0,
|
||||
BatchStrideB0, BatchStrideD0s,
|
||||
BatchStrideB1, BatchStrideD1s,
|
||||
BatchStrideE1, a0_element_op,
|
||||
b0_element_op, cde0_element_op,
|
||||
b1_element_op, cde1_element_op};
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
@@ -902,34 +368,34 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3
|
||||
B1ElementwiseOperation b1_element_op,
|
||||
CDE1ElementwiseOperation c_element_op) override
|
||||
{
|
||||
return std::make_unique<RawArg>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const B0DataType*>(p_b0),
|
||||
p_d0s,
|
||||
static_cast<const B1DataType*>(p_b1),
|
||||
p_d1s,
|
||||
static_cast<E1DataType*>(p_c),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
O,
|
||||
Batch,
|
||||
StrideA,
|
||||
StrideB0,
|
||||
StrideD0s,
|
||||
StrideB1,
|
||||
StrideD1s,
|
||||
StrideE1,
|
||||
BatchStrideA,
|
||||
BatchStrideB0,
|
||||
BatchStrideD0s,
|
||||
BatchStrideB1,
|
||||
BatchStrideD1s,
|
||||
BatchStrideE1,
|
||||
a_element_op,
|
||||
b0_element_op,
|
||||
acc_element_op,
|
||||
b1_element_op,
|
||||
c_element_op);
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const B0DataType*>(p_b0),
|
||||
p_d0s,
|
||||
static_cast<const B1DataType*>(p_b1),
|
||||
p_d1s,
|
||||
static_cast<E1DataType*>(p_c),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
O,
|
||||
Batch,
|
||||
StrideA,
|
||||
StrideB0,
|
||||
StrideD0s,
|
||||
StrideB1,
|
||||
StrideD1s,
|
||||
StrideE1,
|
||||
BatchStrideA,
|
||||
BatchStrideB0,
|
||||
BatchStrideD0s,
|
||||
BatchStrideB1,
|
||||
BatchStrideD1s,
|
||||
BatchStrideE1,
|
||||
a_element_op,
|
||||
b0_element_op,
|
||||
acc_element_op,
|
||||
b1_element_op,
|
||||
c_element_op);
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
@@ -227,7 +227,7 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
sequence<1>>{});
|
||||
else
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding< //
|
||||
tile_distribution_encoding<
|
||||
sequence<NWarps>,
|
||||
tuple<sequence<MWarps, MXdlPack, MPerXdl>,
|
||||
sequence<K_Thread / AK1, K_Lane, AK1 / APackedSize>>,
|
||||
|
||||
@@ -392,8 +392,4 @@ struct BlockReduce2D
|
||||
InDataType reduce_init;
|
||||
};
|
||||
|
||||
// deduction guide
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE_EXTERN BlockReduce2D(const T&, const typename T::DataType&) -> BlockReduce2D<T>;
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -40,7 +40,7 @@ struct BlockSoftmax2D
|
||||
#endif
|
||||
|
||||
// compute row max
|
||||
auto reduce_row_max = BlockReduce2D{x, -numeric<DataType>::infinity()};
|
||||
auto reduce_row_max = BlockReduce2D<decltype(x)>{x, -numeric<DataType>::infinity()};
|
||||
#if _BLOCK_SOFTMAX_USE_UNPACK2
|
||||
auto row_max = reduce_row_max(f_max3, f_max, sequence<1, 2>{});
|
||||
#else
|
||||
|
||||
143
script/tools/ck-build
Executable file
143
script/tools/ck-build
Executable file
@@ -0,0 +1,143 @@
|
||||
#!/bin/bash
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# CK Build - Build Composable Kernel targets in Docker
|
||||
|
||||
set -e
|
||||
set -o pipefail
|
||||
|
||||
# Find script directory and load common utilities
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
source "${SCRIPT_DIR}/common.sh"
|
||||
|
||||
# Initialize configuration
|
||||
PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}")
|
||||
CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}")
|
||||
|
||||
# Help message
|
||||
show_help() {
|
||||
cat << EOF
|
||||
CK Build - Build Composable Kernel targets in Docker
|
||||
|
||||
Usage: ck-build [options] [target...]
|
||||
|
||||
Options:
|
||||
-h, --help Show this help message
|
||||
--name <name> Specify container name
|
||||
--reconfigure Reconfigure CMake before building
|
||||
-j <N> Parallel jobs (passed to ninja)
|
||||
--clean Clean before building
|
||||
|
||||
Arguments:
|
||||
target Target(s) to build (default: all)
|
||||
|
||||
Environment:
|
||||
CK_CONTAINER_NAME - Override default container name
|
||||
GPU_TARGET - Override GPU target detection (e.g., gfx950, gfx942)
|
||||
|
||||
Examples:
|
||||
ck-build # Build all targets
|
||||
ck-build test_amdgcn_mma # Build specific target
|
||||
ck-build test_amdgcn_mma test_gemm # Build multiple targets
|
||||
ck-build --reconfigure # Reconfigure CMake and build all
|
||||
ck-build --clean test_amdgcn_mma # Clean and build target
|
||||
ck-build -j 8 test_amdgcn_mma # Build with 8 parallel jobs
|
||||
|
||||
EOF
|
||||
}
|
||||
|
||||
# Parse arguments
|
||||
targets=()
|
||||
reconfigure=false
|
||||
clean=false
|
||||
parallel_jobs=""
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
-h|--help)
|
||||
show_help
|
||||
exit 0
|
||||
;;
|
||||
--name)
|
||||
CONTAINER_NAME="$2"
|
||||
shift 2
|
||||
;;
|
||||
--reconfigure)
|
||||
reconfigure=true
|
||||
shift
|
||||
;;
|
||||
--clean)
|
||||
clean=true
|
||||
shift
|
||||
;;
|
||||
-j)
|
||||
parallel_jobs="-j $2"
|
||||
shift 2
|
||||
;;
|
||||
*)
|
||||
targets+=("$1")
|
||||
shift
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# Ensure container is running
|
||||
if ! container_is_running "${CONTAINER_NAME}"; then
|
||||
echo "Container '${CONTAINER_NAME}' not running. Starting..."
|
||||
"${SCRIPT_DIR}/ck-start" "${CONTAINER_NAME}"
|
||||
echo ""
|
||||
fi
|
||||
|
||||
# Configure CMake if needed or requested
|
||||
if [ "$reconfigure" = true ] || ! docker exec "${CONTAINER_NAME}" test -f /workspace/build/build.ninja 2>/dev/null; then
|
||||
echo "Detecting GPU target..."
|
||||
GPU_TARGET_DETECTED=$(detect_gpu_target "${CONTAINER_NAME}")
|
||||
|
||||
if [ "$reconfigure" = true ]; then
|
||||
echo "Reconfiguring CMake from scratch for GPU target: ${GPU_TARGET_DETECTED}"
|
||||
else
|
||||
echo "Configuring build with CMake for GPU target: ${GPU_TARGET_DETECTED}"
|
||||
fi
|
||||
|
||||
docker exec "${CONTAINER_NAME}" bash -c "
|
||||
cd /workspace || exit 1
|
||||
rm -rf /workspace/build
|
||||
mkdir /workspace/build
|
||||
cd /workspace/build || exit 1
|
||||
cmake .. -GNinja \
|
||||
-DGPU_TARGETS=${GPU_TARGET_DETECTED} \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DCMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \
|
||||
-DBUILD_TESTING=ON 2>&1 | tail -30
|
||||
"
|
||||
echo ""
|
||||
fi
|
||||
|
||||
# Clean if requested
|
||||
if [ "$clean" = true ]; then
|
||||
echo "Cleaning build directory..."
|
||||
docker exec "${CONTAINER_NAME}" bash -c "
|
||||
cd /workspace/build || exit 1
|
||||
ninja clean
|
||||
"
|
||||
echo ""
|
||||
fi
|
||||
|
||||
# Build targets
|
||||
if [ ${#targets[@]} -eq 0 ]; then
|
||||
echo "Building all configured targets..."
|
||||
docker exec "${CONTAINER_NAME}" bash -c "
|
||||
cd /workspace/build || exit 1
|
||||
ninja ${parallel_jobs} 2>&1
|
||||
"
|
||||
else
|
||||
echo "Building targets: ${targets[*]}"
|
||||
docker exec "${CONTAINER_NAME}" bash -c "
|
||||
cd /workspace/build || exit 1
|
||||
ninja ${parallel_jobs} ${targets[*]} 2>&1
|
||||
"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "Build complete ✓"
|
||||
113
script/tools/ck-clean
Executable file
113
script/tools/ck-clean
Executable file
@@ -0,0 +1,113 @@
|
||||
#!/bin/bash
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# CK Clean - Clean build artifacts in Docker container
|
||||
|
||||
set -e
|
||||
set -o pipefail
|
||||
|
||||
# Find script directory and load common utilities
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
source "${SCRIPT_DIR}/common.sh"
|
||||
|
||||
# Initialize configuration
|
||||
PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}")
|
||||
CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}")
|
||||
|
||||
# Help message
|
||||
show_help() {
|
||||
cat << EOF
|
||||
CK Clean - Clean build artifacts in Docker container
|
||||
|
||||
Usage: ck-clean [options]
|
||||
|
||||
Options:
|
||||
-h, --help Show this help message
|
||||
--name <name> Specify container name
|
||||
--all Remove entire build directory
|
||||
-f, --force Force without confirmation
|
||||
|
||||
Environment:
|
||||
CK_CONTAINER_NAME - Override default container name
|
||||
|
||||
Examples:
|
||||
ck-clean # Clean build artifacts (ninja clean)
|
||||
ck-clean --all # Remove entire build directory
|
||||
ck-clean --force --all # Remove build directory without confirmation
|
||||
|
||||
EOF
|
||||
}
|
||||
|
||||
# Parse arguments
|
||||
remove_all=false
|
||||
force=false
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
-h|--help)
|
||||
show_help
|
||||
exit 0
|
||||
;;
|
||||
--name)
|
||||
CONTAINER_NAME="$2"
|
||||
shift 2
|
||||
;;
|
||||
--all)
|
||||
remove_all=true
|
||||
shift
|
||||
;;
|
||||
-f|--force)
|
||||
force=true
|
||||
shift
|
||||
;;
|
||||
*)
|
||||
echo "Unknown option: $1"
|
||||
show_help
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# Check if container is running
|
||||
if ! container_is_running "${CONTAINER_NAME}"; then
|
||||
echo "Container '${CONTAINER_NAME}' not running"
|
||||
echo "Start with: ck-start"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Check if build directory exists
|
||||
if ! docker exec "${CONTAINER_NAME}" test -d /workspace/build 2>/dev/null; then
|
||||
echo "Build directory does not exist"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
if [ "$remove_all" = true ]; then
|
||||
# Remove entire build directory
|
||||
if [ "$force" = false ]; then
|
||||
read -p "Remove entire build directory? (y/N) " -n 1 -r
|
||||
echo ""
|
||||
if [[ ! $REPLY =~ ^[Yy]$ ]]; then
|
||||
echo "Cancelled"
|
||||
exit 0
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "Removing build directory..."
|
||||
docker exec "${CONTAINER_NAME}" bash -c "rm -rf /workspace/build"
|
||||
echo "Build directory removed ✓"
|
||||
else
|
||||
# Clean with ninja
|
||||
if ! docker exec "${CONTAINER_NAME}" test -f /workspace/build/build.ninja 2>/dev/null; then
|
||||
echo "Build not configured (build.ninja not found)"
|
||||
echo "Use --all to remove build directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Cleaning build artifacts..."
|
||||
docker exec "${CONTAINER_NAME}" bash -c "
|
||||
cd /workspace/build || exit 1
|
||||
ninja clean
|
||||
"
|
||||
echo "Build artifacts cleaned ✓"
|
||||
fi
|
||||
111
script/tools/ck-exec
Executable file
111
script/tools/ck-exec
Executable file
@@ -0,0 +1,111 @@
|
||||
#!/bin/bash
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# CK Exec - Execute arbitrary commands in Docker container
|
||||
|
||||
set -e
|
||||
set -o pipefail
|
||||
|
||||
# Find script directory and load common utilities
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
source "${SCRIPT_DIR}/common.sh"
|
||||
|
||||
# Initialize configuration
|
||||
PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}")
|
||||
CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}")
|
||||
|
||||
# Help message
|
||||
show_help() {
|
||||
cat << EOF
|
||||
CK Exec - Execute arbitrary commands in Docker container
|
||||
|
||||
Usage: ck-exec [options] <command> [args...]
|
||||
|
||||
Options:
|
||||
-h, --help Show this help message
|
||||
--name <name> Specify container name
|
||||
-w <dir> Working directory (default: /workspace)
|
||||
-i, --interactive Interactive mode (allocate TTY)
|
||||
|
||||
Arguments:
|
||||
command Command to execute (required)
|
||||
args Arguments to the command
|
||||
|
||||
Environment:
|
||||
CK_CONTAINER_NAME - Override default container name
|
||||
|
||||
Examples:
|
||||
ck-exec rocm-smi # Run rocm-smi
|
||||
ck-exec rocminfo # Run rocminfo
|
||||
ck-exec ls -la build/bin # List build binaries
|
||||
ck-exec -w /workspace/build ninja -t commands # Run ninja commands
|
||||
ck-exec --interactive python3 # Interactive Python session
|
||||
|
||||
Common Commands:
|
||||
ck-exec rocm-smi # Check GPU status
|
||||
ck-exec rocminfo \| grep gfx # Check GPU architecture
|
||||
ck-exec hipcc --version # Check HIP compiler version
|
||||
ck-exec cmake --version # Check CMake version
|
||||
ck-exec ninja -C build -t targets # List all build targets
|
||||
|
||||
EOF
|
||||
}
|
||||
|
||||
# Parse arguments
|
||||
workdir="/workspace"
|
||||
interactive=false
|
||||
command_args=()
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
-h|--help)
|
||||
show_help
|
||||
exit 0
|
||||
;;
|
||||
--name)
|
||||
CONTAINER_NAME="$2"
|
||||
shift 2
|
||||
;;
|
||||
-w)
|
||||
workdir="$2"
|
||||
shift 2
|
||||
;;
|
||||
-i|--interactive)
|
||||
interactive=true
|
||||
shift
|
||||
;;
|
||||
*)
|
||||
command_args+=("$1")
|
||||
shift
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# Validate command
|
||||
if [ ${#command_args[@]} -eq 0 ]; then
|
||||
echo "Error: command required"
|
||||
echo ""
|
||||
show_help
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Ensure container is running
|
||||
if ! container_is_running "${CONTAINER_NAME}"; then
|
||||
echo "Container '${CONTAINER_NAME}' not running. Starting..."
|
||||
"${SCRIPT_DIR}/ck-start" "${CONTAINER_NAME}"
|
||||
echo ""
|
||||
fi
|
||||
|
||||
# Build command string
|
||||
cmd_string=""
|
||||
for arg in "${command_args[@]}"; do
|
||||
cmd_string="${cmd_string} $(printf '%q' "$arg")"
|
||||
done
|
||||
|
||||
# Execute command
|
||||
if [ "$interactive" = true ]; then
|
||||
docker exec -it -w "${workdir}" "${CONTAINER_NAME}" bash -c "${cmd_string}"
|
||||
else
|
||||
docker exec -w "${workdir}" "${CONTAINER_NAME}" bash -c "${cmd_string}"
|
||||
fi
|
||||
134
script/tools/ck-logs
Executable file
134
script/tools/ck-logs
Executable file
@@ -0,0 +1,134 @@
|
||||
#!/bin/bash
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# CK Logs - View container logs and build output
|
||||
|
||||
set -e
|
||||
set -o pipefail
|
||||
|
||||
# Find script directory and load common utilities
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
source "${SCRIPT_DIR}/common.sh"
|
||||
|
||||
# Initialize configuration
|
||||
PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}")
|
||||
CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}")
|
||||
|
||||
# Help message
|
||||
show_help() {
|
||||
cat << EOF
|
||||
CK Logs - View container logs and build output
|
||||
|
||||
Usage: ck-logs [options] [container_name]
|
||||
|
||||
Options:
|
||||
-h, --help Show this help message
|
||||
--name <name> Specify container name
|
||||
-f, --follow Follow log output
|
||||
-n, --tail <N> Show last N lines (default: 100)
|
||||
--cmake Show CMake configuration log
|
||||
--build Show last build log
|
||||
|
||||
Arguments:
|
||||
container_name Optional container name (default: ck_<username>_<branch>)
|
||||
|
||||
Environment:
|
||||
CK_CONTAINER_NAME - Override default container name
|
||||
|
||||
Examples:
|
||||
ck-logs # Show last 100 lines of container logs
|
||||
ck-logs -f # Follow container logs
|
||||
ck-logs -n 500 # Show last 500 lines
|
||||
ck-logs --cmake # Show CMake configuration
|
||||
ck-logs --build # Show build log
|
||||
|
||||
EOF
|
||||
}
|
||||
|
||||
# Parse arguments
|
||||
follow=false
|
||||
tail_lines=100
|
||||
show_cmake=false
|
||||
show_build=false
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
-h|--help)
|
||||
show_help
|
||||
exit 0
|
||||
;;
|
||||
--name)
|
||||
CONTAINER_NAME="$2"
|
||||
shift 2
|
||||
;;
|
||||
-f|--follow)
|
||||
follow=true
|
||||
shift
|
||||
;;
|
||||
-n|--tail)
|
||||
tail_lines="$2"
|
||||
shift 2
|
||||
;;
|
||||
--cmake)
|
||||
show_cmake=true
|
||||
shift
|
||||
;;
|
||||
--build)
|
||||
show_build=true
|
||||
shift
|
||||
;;
|
||||
*)
|
||||
CONTAINER_NAME="$1"
|
||||
shift
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# Check if container exists
|
||||
if ! container_exists "${CONTAINER_NAME}"; then
|
||||
echo "Container '${CONTAINER_NAME}' does not exist"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Show CMake log
|
||||
if [ "$show_cmake" = true ]; then
|
||||
echo "CMake Configuration Log:"
|
||||
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
|
||||
|
||||
if docker exec "${CONTAINER_NAME}" test -f /workspace/build/CMakeCache.txt 2>/dev/null; then
|
||||
docker exec "${CONTAINER_NAME}" bash -c "
|
||||
cd /workspace/build
|
||||
echo 'GPU_TARGETS:' \$(grep 'GPU_TARGETS:' CMakeCache.txt | cut -d'=' -f2)
|
||||
echo 'CMAKE_BUILD_TYPE:' \$(grep 'CMAKE_BUILD_TYPE:' CMakeCache.txt | cut -d'=' -f2)
|
||||
echo 'CMAKE_CXX_COMPILER:' \$(grep 'CMAKE_CXX_COMPILER:' CMakeCache.txt | cut -d'=' -f2)
|
||||
echo 'BUILD_TESTING:' \$(grep 'BUILD_TESTING:' CMakeCache.txt | cut -d'=' -f2)
|
||||
"
|
||||
else
|
||||
echo "CMake not configured yet"
|
||||
fi
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Show build log (last build output)
|
||||
if [ "$show_build" = true ]; then
|
||||
echo "Last Build Log:"
|
||||
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
|
||||
|
||||
if docker exec "${CONTAINER_NAME}" test -f /workspace/build/.ninja_log 2>/dev/null; then
|
||||
docker exec "${CONTAINER_NAME}" bash -c "tail -50 /workspace/build/.ninja_log"
|
||||
else
|
||||
echo "No build log found"
|
||||
fi
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Show container logs
|
||||
echo "Container Logs (${CONTAINER_NAME}):"
|
||||
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
|
||||
|
||||
if [ "$follow" = true ]; then
|
||||
docker logs -f "${CONTAINER_NAME}"
|
||||
else
|
||||
docker logs --tail "${tail_lines}" "${CONTAINER_NAME}"
|
||||
fi
|
||||
84
script/tools/ck-shell
Executable file
84
script/tools/ck-shell
Executable file
@@ -0,0 +1,84 @@
|
||||
#!/bin/bash
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# CK Shell - Open interactive shell in Docker container
|
||||
|
||||
set -e
|
||||
set -o pipefail
|
||||
|
||||
# Find script directory and load common utilities
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
source "${SCRIPT_DIR}/common.sh"
|
||||
|
||||
# Initialize configuration
|
||||
PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}")
|
||||
CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}")
|
||||
|
||||
# Help message
|
||||
show_help() {
|
||||
cat << EOF
|
||||
CK Shell - Open interactive shell in Docker container
|
||||
|
||||
Usage: ck-shell [options] [container_name]
|
||||
|
||||
Options:
|
||||
-h, --help Show this help message
|
||||
--name <name> Specify container name
|
||||
-c <command> Execute command instead of interactive shell
|
||||
|
||||
Arguments:
|
||||
container_name Optional container name (default: ck_<username>_<branch>)
|
||||
|
||||
Environment:
|
||||
CK_CONTAINER_NAME - Override default container name
|
||||
|
||||
Examples:
|
||||
ck-shell # Open interactive shell
|
||||
ck-shell my_container # Open shell in specific container
|
||||
ck-shell -c "rocm-smi" # Execute single command
|
||||
ck-shell -c "cd build && ls bin" # Execute command in build directory
|
||||
|
||||
EOF
|
||||
}
|
||||
|
||||
# Parse arguments
|
||||
command=""
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
-h|--help)
|
||||
show_help
|
||||
exit 0
|
||||
;;
|
||||
--name)
|
||||
CONTAINER_NAME="$2"
|
||||
shift 2
|
||||
;;
|
||||
-c)
|
||||
command="$2"
|
||||
shift 2
|
||||
;;
|
||||
*)
|
||||
CONTAINER_NAME="$1"
|
||||
shift
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# Ensure container is running
|
||||
if ! container_is_running "${CONTAINER_NAME}"; then
|
||||
echo "Container '${CONTAINER_NAME}' not running. Starting..."
|
||||
"${SCRIPT_DIR}/ck-start" "${CONTAINER_NAME}"
|
||||
echo ""
|
||||
fi
|
||||
|
||||
# Execute command or open shell
|
||||
if [ -n "$command" ]; then
|
||||
echo "Executing: ${command}"
|
||||
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
|
||||
docker exec "${CONTAINER_NAME}" bash -c "${command}"
|
||||
else
|
||||
echo "Opening shell in '${CONTAINER_NAME}' (type 'exit' to leave)..."
|
||||
docker exec -it "${CONTAINER_NAME}" bash
|
||||
fi
|
||||
103
script/tools/ck-start
Executable file
103
script/tools/ck-start
Executable file
@@ -0,0 +1,103 @@
|
||||
#!/bin/bash
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# CK Start - Start Docker container for Composable Kernel testing
|
||||
|
||||
set -e
|
||||
set -o pipefail
|
||||
|
||||
# Find script directory and load common utilities
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
source "${SCRIPT_DIR}/common.sh"
|
||||
|
||||
# Initialize configuration
|
||||
PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}")
|
||||
CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}")
|
||||
|
||||
# Help message
|
||||
show_help() {
|
||||
cat << EOF
|
||||
CK Start - Start Docker container for Composable Kernel testing
|
||||
|
||||
Usage: ck-start [options] [container_name]
|
||||
|
||||
Options:
|
||||
-h, --help Show this help message
|
||||
--image <image> Specify Docker image (overrides CK_DOCKER_IMAGE)
|
||||
|
||||
Arguments:
|
||||
container_name Optional container name (default: ck_<username>_<branch>)
|
||||
|
||||
Environment:
|
||||
CK_CONTAINER_NAME - Override default container name
|
||||
CK_DOCKER_IMAGE - Override Docker image (default: rocm/composable_kernel:ck_ub24.04_rocm7.0.1)
|
||||
|
||||
Examples:
|
||||
ck-start # Start container with default name
|
||||
ck-start my_ck_container # Start container with custom name
|
||||
ck-start --image rocm/composable_kernel:latest
|
||||
|
||||
EOF
|
||||
}
|
||||
|
||||
# Parse arguments
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
-h|--help)
|
||||
show_help
|
||||
exit 0
|
||||
;;
|
||||
--image)
|
||||
export CK_DOCKER_IMAGE="$2"
|
||||
shift 2
|
||||
;;
|
||||
*)
|
||||
CONTAINER_NAME="$1"
|
||||
shift
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# Get Docker image
|
||||
DOCKER_IMAGE=$(get_docker_image)
|
||||
|
||||
# Check if container exists and is running
|
||||
if container_exists "${CONTAINER_NAME}"; then
|
||||
if container_is_running "${CONTAINER_NAME}"; then
|
||||
echo "Container '${CONTAINER_NAME}' is already running"
|
||||
docker exec "${CONTAINER_NAME}" bash -c "echo 'Working directory:' && pwd"
|
||||
exit 0
|
||||
else
|
||||
echo "Starting existing container '${CONTAINER_NAME}'..."
|
||||
docker start "${CONTAINER_NAME}"
|
||||
echo "Container started"
|
||||
docker exec "${CONTAINER_NAME}" bash -c "echo 'Working directory:' && pwd"
|
||||
exit 0
|
||||
fi
|
||||
fi
|
||||
|
||||
# Create new container
|
||||
echo "Creating new Docker container '${CONTAINER_NAME}'..."
|
||||
echo "Docker image: ${DOCKER_IMAGE}"
|
||||
echo "Project root: ${PROJECT_ROOT}"
|
||||
echo ""
|
||||
|
||||
docker run -d \
|
||||
--name "${CONTAINER_NAME}" \
|
||||
--device=/dev/kfd --device=/dev/dri \
|
||||
--security-opt seccomp=unconfined \
|
||||
--group-add video \
|
||||
-v "${PROJECT_ROOT}":/workspace \
|
||||
-w /workspace \
|
||||
"${DOCKER_IMAGE}" \
|
||||
tail -f /dev/null
|
||||
|
||||
echo ""
|
||||
echo "Container '${CONTAINER_NAME}' started successfully"
|
||||
docker exec "${CONTAINER_NAME}" bash -c "echo 'Working directory:' && pwd"
|
||||
|
||||
# Show GPU info
|
||||
echo ""
|
||||
echo "GPU Information:"
|
||||
docker exec "${CONTAINER_NAME}" bash -c "rocm-smi --showproductname 2>/dev/null | head -5 || echo 'No GPU detected'"
|
||||
153
script/tools/ck-status
Executable file
153
script/tools/ck-status
Executable file
@@ -0,0 +1,153 @@
|
||||
#!/bin/bash
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# CK Status - Check container status and information
|
||||
|
||||
set -e
|
||||
set -o pipefail
|
||||
|
||||
# Find script directory and load common utilities
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
source "${SCRIPT_DIR}/common.sh"
|
||||
|
||||
# Initialize configuration
|
||||
PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}")
|
||||
CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}")
|
||||
|
||||
# Help message
|
||||
show_help() {
|
||||
cat << EOF
|
||||
CK Status - Check container status and information
|
||||
|
||||
Usage: ck-status [options] [container_name]
|
||||
|
||||
Options:
|
||||
-h, --help Show this help message
|
||||
--name <name> Specify container name
|
||||
--all Show all CK containers
|
||||
-v, --verbose Show detailed information
|
||||
|
||||
Arguments:
|
||||
container_name Optional container name (default: ck_<username>_<branch>)
|
||||
|
||||
Environment:
|
||||
CK_CONTAINER_NAME - Override default container name
|
||||
|
||||
Examples:
|
||||
ck-status # Check default container status
|
||||
ck-status my_container # Check specific container
|
||||
ck-status --all # Show all CK containers
|
||||
ck-status -v # Show detailed information
|
||||
|
||||
EOF
|
||||
}
|
||||
|
||||
# Parse arguments
|
||||
show_all=false
|
||||
verbose=false
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
-h|--help)
|
||||
show_help
|
||||
exit 0
|
||||
;;
|
||||
--name)
|
||||
CONTAINER_NAME="$2"
|
||||
shift 2
|
||||
;;
|
||||
--all)
|
||||
show_all=true
|
||||
shift
|
||||
;;
|
||||
-v|--verbose)
|
||||
verbose=true
|
||||
shift
|
||||
;;
|
||||
*)
|
||||
CONTAINER_NAME="$1"
|
||||
shift
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
DOCKER_IMAGE=$(get_docker_image)
|
||||
|
||||
# Show all containers
|
||||
if [ "$show_all" = true ]; then
|
||||
echo "Composable Kernel Docker Containers:"
|
||||
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
|
||||
|
||||
username=$(get_username)
|
||||
containers=$(docker ps -a --filter "name=ck_${username}_" --format "table {{.Names}}\t{{.Status}}\t{{.CreatedAt}}" 2>/dev/null || echo "")
|
||||
|
||||
if [ -z "$containers" ] || [ "$containers" = "NAMES STATUS CREATED AT" ]; then
|
||||
echo "No CK containers found for user '${username}'"
|
||||
else
|
||||
echo "$containers"
|
||||
fi
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Check specific container status
|
||||
echo "Container: ${CONTAINER_NAME}"
|
||||
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
|
||||
|
||||
if container_is_running "${CONTAINER_NAME}"; then
|
||||
echo "Status: RUNNING ✓"
|
||||
echo ""
|
||||
docker ps --filter "name=^${CONTAINER_NAME}$" --format "table {{.Names}}\t{{.Status}}\t{{.Image}}"
|
||||
|
||||
if [ "$verbose" = true ]; then
|
||||
echo ""
|
||||
echo "Container Details:"
|
||||
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
|
||||
docker inspect "${CONTAINER_NAME}" --format '
|
||||
Image: {{.Config.Image}}
|
||||
Created: {{.Created}}
|
||||
Platform: {{.Platform}}
|
||||
Mounts: {{range .Mounts}}
|
||||
- {{.Source}} -> {{.Destination}}{{end}}
|
||||
'
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "GPU Information:"
|
||||
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
|
||||
docker exec "${CONTAINER_NAME}" bash -c "rocm-smi --showproductname 2>/dev/null | head -10 || echo 'No GPU detected'"
|
||||
|
||||
if [ "$verbose" = true ]; then
|
||||
echo ""
|
||||
echo "GPU Target:"
|
||||
gpu_target=$(detect_gpu_target "${CONTAINER_NAME}")
|
||||
echo " ${gpu_target}"
|
||||
|
||||
echo ""
|
||||
echo "Build Status:"
|
||||
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
|
||||
if docker exec "${CONTAINER_NAME}" test -d /workspace/build 2>/dev/null; then
|
||||
if docker exec "${CONTAINER_NAME}" test -f /workspace/build/build.ninja 2>/dev/null; then
|
||||
echo " CMake configured ✓"
|
||||
echo " Build directory: /workspace/build"
|
||||
|
||||
# Count built test binaries
|
||||
bin_count=$(docker exec "${CONTAINER_NAME}" bash -c "ls -1 /workspace/build/bin 2>/dev/null | wc -l" || echo "0")
|
||||
echo " Test binaries: ${bin_count}"
|
||||
else
|
||||
echo " CMake not configured"
|
||||
fi
|
||||
else
|
||||
echo " Build directory not found"
|
||||
fi
|
||||
fi
|
||||
|
||||
elif container_exists "${CONTAINER_NAME}"; then
|
||||
echo "Status: STOPPED"
|
||||
echo ""
|
||||
echo "Start with: ck-start"
|
||||
else
|
||||
echo "Status: DOES NOT EXIST"
|
||||
echo ""
|
||||
echo "Create with: ck-start"
|
||||
fi
|
||||
141
script/tools/ck-stop
Executable file
141
script/tools/ck-stop
Executable file
@@ -0,0 +1,141 @@
|
||||
#!/bin/bash
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# CK Stop - Stop and remove Docker container
|
||||
|
||||
set -e
|
||||
set -o pipefail
|
||||
|
||||
# Find script directory and load common utilities
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
source "${SCRIPT_DIR}/common.sh"
|
||||
|
||||
# Initialize configuration
|
||||
PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}")
|
||||
CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}")
|
||||
|
||||
# Help message
|
||||
show_help() {
|
||||
cat << EOF
|
||||
CK Stop - Stop and remove Docker container
|
||||
|
||||
Usage: ck-stop [options] [container_name]
|
||||
|
||||
Options:
|
||||
-h, --help Show this help message
|
||||
-f, --force Force stop without confirmation
|
||||
--all Stop all CK containers for this user
|
||||
|
||||
Arguments:
|
||||
container_name Optional container name (default: ck_<username>_<branch>)
|
||||
|
||||
Environment:
|
||||
CK_CONTAINER_NAME - Override default container name
|
||||
|
||||
Examples:
|
||||
ck-stop # Stop default container
|
||||
ck-stop my_ck_container # Stop specific container
|
||||
ck-stop --all # Stop all user's CK containers
|
||||
ck-stop --force # Stop without confirmation
|
||||
|
||||
EOF
|
||||
}
|
||||
|
||||
# Parse arguments
|
||||
force=false
|
||||
stop_all=false
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
-h|--help)
|
||||
show_help
|
||||
exit 0
|
||||
;;
|
||||
-f|--force)
|
||||
force=true
|
||||
shift
|
||||
;;
|
||||
--all)
|
||||
stop_all=true
|
||||
shift
|
||||
;;
|
||||
*)
|
||||
CONTAINER_NAME="$1"
|
||||
shift
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# Function to stop a single container
|
||||
stop_container() {
|
||||
local name="$1"
|
||||
|
||||
if ! container_exists "${name}"; then
|
||||
echo "Container '${name}' does not exist"
|
||||
return 1
|
||||
fi
|
||||
|
||||
echo "Stopping and removing container '${name}'..."
|
||||
docker stop "${name}" 2>/dev/null || true
|
||||
docker rm "${name}" 2>/dev/null || true
|
||||
echo "Container '${name}' stopped and removed"
|
||||
}
|
||||
|
||||
# Stop all user containers
|
||||
if [ "$stop_all" = true ]; then
|
||||
username=$(get_username)
|
||||
containers=$(docker ps -a --filter "name=ck_${username}_" --format '{{.Names}}')
|
||||
|
||||
if [ -z "$containers" ]; then
|
||||
echo "No CK containers found for user '${username}'"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "Found CK containers for user '${username}':"
|
||||
echo "$containers"
|
||||
echo ""
|
||||
|
||||
if [ "$force" = false ]; then
|
||||
read -p "Stop and remove all these containers? (y/N) " -n 1 -r
|
||||
echo ""
|
||||
if [[ ! $REPLY =~ ^[Yy]$ ]]; then
|
||||
echo "Cancelled"
|
||||
exit 0
|
||||
fi
|
||||
fi
|
||||
|
||||
echo ""
|
||||
while IFS= read -r container; do
|
||||
stop_container "$container"
|
||||
done <<< "$containers"
|
||||
|
||||
echo ""
|
||||
echo "All containers stopped and removed"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Stop single container
|
||||
if ! container_exists "${CONTAINER_NAME}"; then
|
||||
echo "Container '${CONTAINER_NAME}' does not exist"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Show container info
|
||||
if container_is_running "${CONTAINER_NAME}"; then
|
||||
echo "Container '${CONTAINER_NAME}' is currently running"
|
||||
else
|
||||
echo "Container '${CONTAINER_NAME}' exists but is stopped"
|
||||
fi
|
||||
|
||||
# Confirm if not forced
|
||||
if [ "$force" = false ]; then
|
||||
read -p "Stop and remove container '${CONTAINER_NAME}'? (y/N) " -n 1 -r
|
||||
echo ""
|
||||
if [[ ! $REPLY =~ ^[Yy]$ ]]; then
|
||||
echo "Cancelled"
|
||||
exit 0
|
||||
fi
|
||||
fi
|
||||
|
||||
stop_container "${CONTAINER_NAME}"
|
||||
166
script/tools/ck-test
Executable file
166
script/tools/ck-test
Executable file
@@ -0,0 +1,166 @@
|
||||
#!/bin/bash
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# CK Test - Build and test Composable Kernel in Docker
|
||||
|
||||
set -e
|
||||
set -o pipefail
|
||||
|
||||
# Find script directory and load common utilities
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
source "${SCRIPT_DIR}/common.sh"
|
||||
|
||||
# Initialize configuration
|
||||
PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}")
|
||||
CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}")
|
||||
|
||||
# Help message
|
||||
show_help() {
|
||||
cat << EOF
|
||||
CK Test - Build and test Composable Kernel in Docker
|
||||
|
||||
Usage: ck-test [options] <test_name> [test_options]
|
||||
|
||||
Options:
|
||||
-h, --help Show this help message
|
||||
--name <name> Specify container name
|
||||
--reconfigure Reconfigure CMake before building
|
||||
--no-build Skip building, run test directly
|
||||
|
||||
Arguments:
|
||||
test_name Name of test executable (required)
|
||||
test_options Additional options passed to test (e.g., --gtest_filter=*)
|
||||
|
||||
Environment:
|
||||
CK_CONTAINER_NAME - Override default container name
|
||||
GPU_TARGET - Override GPU target detection (e.g., gfx950, gfx942)
|
||||
|
||||
Examples:
|
||||
ck-test test_amdgcn_mma
|
||||
ck-test test_amdgcn_mma --gtest_filter=*Fp16*
|
||||
ck-test --name my_container test_amdgcn_mma
|
||||
ck-test --reconfigure test_amdgcn_mma
|
||||
|
||||
EOF
|
||||
}
|
||||
|
||||
# Parse arguments
|
||||
test_name=""
|
||||
reconfigure=false
|
||||
no_build=false
|
||||
test_options=()
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
-h|--help)
|
||||
show_help
|
||||
exit 0
|
||||
;;
|
||||
--name)
|
||||
CONTAINER_NAME="$2"
|
||||
shift 2
|
||||
;;
|
||||
--reconfigure)
|
||||
reconfigure=true
|
||||
shift
|
||||
;;
|
||||
--no-build)
|
||||
no_build=true
|
||||
shift
|
||||
;;
|
||||
--gtest_*|--help)
|
||||
test_options+=("$1")
|
||||
shift
|
||||
;;
|
||||
*)
|
||||
if [ -z "$test_name" ]; then
|
||||
test_name="$1"
|
||||
else
|
||||
test_options+=("$1")
|
||||
fi
|
||||
shift
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# Validate test name
|
||||
if [ -z "$test_name" ]; then
|
||||
echo "Error: test_name required"
|
||||
echo ""
|
||||
show_help
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Ensure container is running
|
||||
if ! container_is_running "${CONTAINER_NAME}"; then
|
||||
echo "Container '${CONTAINER_NAME}' not running. Starting..."
|
||||
"${SCRIPT_DIR}/ck-start" "${CONTAINER_NAME}"
|
||||
echo ""
|
||||
fi
|
||||
|
||||
# Configure CMake if needed or requested
|
||||
if [ "$reconfigure" = true ] || ! docker exec "${CONTAINER_NAME}" test -f /workspace/build/build.ninja 2>/dev/null; then
|
||||
echo "Detecting GPU target..."
|
||||
GPU_TARGET_DETECTED=$(detect_gpu_target "${CONTAINER_NAME}")
|
||||
|
||||
if [ "$reconfigure" = true ]; then
|
||||
echo "Reconfiguring CMake from scratch for GPU target: ${GPU_TARGET_DETECTED}"
|
||||
else
|
||||
echo "Configuring build with CMake for GPU target: ${GPU_TARGET_DETECTED}"
|
||||
fi
|
||||
|
||||
docker exec "${CONTAINER_NAME}" bash -c "
|
||||
cd /workspace || exit 1
|
||||
rm -rf /workspace/build
|
||||
mkdir /workspace/build
|
||||
cd /workspace/build || exit 1
|
||||
cmake .. -GNinja \
|
||||
-DGPU_TARGETS=${GPU_TARGET_DETECTED} \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DCMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \
|
||||
-DBUILD_TESTING=ON 2>&1 | tail -30
|
||||
"
|
||||
echo ""
|
||||
fi
|
||||
|
||||
# Build test if needed (unless --no-build is specified)
|
||||
if [ "$no_build" = false ]; then
|
||||
if ! docker exec "${CONTAINER_NAME}" test -f "/workspace/build/bin/${test_name}" 2>/dev/null; then
|
||||
echo "Building ${test_name}..."
|
||||
docker exec "${CONTAINER_NAME}" bash -c "
|
||||
cd /workspace/build || exit 1
|
||||
ninja ${test_name} 2>&1
|
||||
"
|
||||
echo ""
|
||||
else
|
||||
echo "Test executable found, rebuilding to ensure latest version..."
|
||||
docker exec "${CONTAINER_NAME}" bash -c "
|
||||
cd /workspace/build || exit 1
|
||||
ninja ${test_name} 2>&1
|
||||
"
|
||||
echo ""
|
||||
fi
|
||||
fi
|
||||
|
||||
# Run test
|
||||
echo "Running: ${test_name} ${test_options[*]}"
|
||||
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
|
||||
|
||||
# Build the command with proper quoting
|
||||
cmd="cd /workspace/build && ./bin/${test_name}"
|
||||
for opt in "${test_options[@]}"; do
|
||||
cmd="${cmd} $(printf '%q' "$opt")"
|
||||
done
|
||||
|
||||
docker exec "${CONTAINER_NAME}" bash -c "${cmd}"
|
||||
exit_code=$?
|
||||
|
||||
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
|
||||
if [ $exit_code -eq 0 ]; then
|
||||
echo "Test completed successfully"
|
||||
else
|
||||
echo "Test failed with exit code: ${exit_code}"
|
||||
fi
|
||||
|
||||
exit $exit_code
|
||||
@@ -13,13 +13,8 @@ class TestCkTileGemmPipelineCompV3
|
||||
static constexpr bool check_data_type()
|
||||
{
|
||||
using Base = TestCkTileGemmPipeline<T, TestCkTileGemmPipelineCompV3<T>>;
|
||||
if constexpr(std::is_same_v<typename Base::ADataType, F8> &&
|
||||
std::is_same_v<typename Base::BDataType, BF8>)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Base::BLayout, Row> &&
|
||||
std::is_same_v<typename Base::BDataType, I4>)
|
||||
if constexpr(std::is_same_v<typename Base::BLayout, Row> &&
|
||||
std::is_same_v<typename Base::BDataType, I4>)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -170,7 +170,7 @@ using KernelTypesCompV4 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, BF16, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
|
||||
std::tuple< Row, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
|
||||
std::tuple< Row, Row, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV4>,
|
||||
std::tuple< Row, Row, Row, F8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
|
||||
std::tuple< Row, Row, Row, F8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
|
||||
std::tuple< Row, Row, Row, BF8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
|
||||
std::tuple< Row, Row, Row, BF8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
|
||||
@@ -180,7 +180,7 @@ using KernelTypesCompV4 = ::testing::Types<
|
||||
std::tuple< Row, Col, Row, BF16, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
|
||||
std::tuple< Row, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
|
||||
std::tuple< Row, Col, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV4>,
|
||||
std::tuple< Row, Col, Row, F8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
|
||||
std::tuple< Row, Col, Row, F8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
|
||||
std::tuple< Row, Col, Row, BF8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
|
||||
std::tuple< Row, Col, Row, BF8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
|
||||
@@ -190,7 +190,7 @@ using KernelTypesCompV4 = ::testing::Types<
|
||||
std::tuple< Col, Row, Row, BF16, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
|
||||
std::tuple< Col, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
|
||||
std::tuple< Col, Row, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV4>,
|
||||
std::tuple< Col, Row, Row, F8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
|
||||
std::tuple< Col, Row, Row, F8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
|
||||
std::tuple< Col, Row, Row, BF8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
|
||||
std::tuple< Col, Row, Row, BF8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
|
||||
@@ -200,7 +200,7 @@ using KernelTypesCompV4 = ::testing::Types<
|
||||
std::tuple< Col, Col, Row, BF16, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
|
||||
std::tuple< Col, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
|
||||
std::tuple< Col, Col, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV4>,
|
||||
std::tuple< Col, Col, Row, F8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
|
||||
std::tuple< Col, Col, Row, F8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
|
||||
std::tuple< Col, Col, Row, BF8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
|
||||
std::tuple< Col, Col, Row, BF8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>
|
||||
|
||||
Reference in New Issue
Block a user