Merge commit '2e4b8a8fc455a14ad5cf89f7f750060ff20c40bb' into develop

This commit is contained in:
assistant-librarian[bot]
2025-11-20 17:12:11 +00:00
parent b2e58aec1a
commit 8b93b58bcd
153 changed files with 880 additions and 3829 deletions

View File

@@ -71,16 +71,16 @@ invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
bool flush_cache,
ck_tile::StreamKReductionStrategy reduction_strategy)
{
-ck_tile::reboot::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
-b_k_n_dev_buf.GetDeviceBuffer(),
-c_m_n_dev_buf.GetDeviceBuffer(),
-M,
-N,
-K,
-stride_A,
-stride_B,
-stride_C,
-reduction_strategy};
+ck_tile::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
+b_k_n_dev_buf.GetDeviceBuffer(),
+c_m_n_dev_buf.GetDeviceBuffer(),
+M,
+N,
+K,
+stride_A,
+stride_B,
+stride_C,
+reduction_strategy};
std::tuple<float, ck_tile::index_t> ave_time_and_batch;

View File

@@ -16,7 +16,7 @@ template <typename GemmConfig,
typename ELayout,
typename CDEElementWise,
ck_tile::StreamKReductionStrategy ReductionStrategy>
-std::tuple<float, ck_tile::index_t> gemm(const ck_tile::reboot::StreamKHostArgs& args,
+std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
const ck_tile::stream_config& s)
{
using GemmShape = ck_tile::TileGemmShape<
@@ -28,7 +28,7 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::reboot::StreamKHostArgs&
GemmConfig::PermuteB>;
using TilePartitioner =
-ck_tile::StreamKTilePartitioner_v2<GemmShape, ReductionStrategy, GemmConfig::Persistent>;
+ck_tile::StreamKTilePartitioner<GemmShape, ReductionStrategy, GemmConfig::Persistent>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
@@ -77,7 +77,7 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::reboot::StreamKHostArgs&
memory_operation.value,
GemmConfig::NumWaveGroups>>;
-using Kernel = ck_tile::reboot::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
+using Kernel = ck_tile::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const auto workspace_size = Kernel::GetWorkSpaceSize(kargs);