[CK_TILE] Remove Old CK Tile Stream-K Artifacts (#3202)

* Remove old CK Tile Stream-K implementation

The original CK Stream-K implementation was based on old CK's Stream-K
block to C tile map. However, this implementation did not align with the
original Stream-K paper. Thus, we implemented a new tile partitioner and
associated Stream-K kernel, which was placed in the reboot namespace.

Now that the new Stream-K implementation is ready, this change removes
all artifacts of the old implementation. Specifically, the following
changes were made:
- Removes old Stream-K tile partitioner from CK Tile
- Removes the reboot namespace such that the new implementation resides
  in the ck_tile namespace only.
- Adds tests for bf8 and fp8 using the new implementation
- Removes tests for the old implementation
- Remove the v2 suffix from the new CK Tile Tile Partitioner
derived classes.
- Updates Stream-K Kernel ops file to use /** commenting style.

* Remove v2 from tile partitioner validation function names
This commit is contained in:
Emily Martins
2025-11-20 09:32:32 -07:00
committed by GitHub
parent 5adaa201ed
commit 2e4b8a8fc4
153 changed files with 880 additions and 3829 deletions

View File

@@ -71,16 +71,16 @@ invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
bool flush_cache,
ck_tile::StreamKReductionStrategy reduction_strategy)
{
ck_tile::reboot::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
c_m_n_dev_buf.GetDeviceBuffer(),
M,
N,
K,
stride_A,
stride_B,
stride_C,
reduction_strategy};
ck_tile::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
c_m_n_dev_buf.GetDeviceBuffer(),
M,
N,
K,
stride_A,
stride_B,
stride_C,
reduction_strategy};
std::tuple<float, ck_tile::index_t> ave_time_and_batch;

View File

@@ -16,7 +16,7 @@ template <typename GemmConfig,
typename ELayout,
typename CDEElementWise,
ck_tile::StreamKReductionStrategy ReductionStrategy>
std::tuple<float, ck_tile::index_t> gemm(const ck_tile::reboot::StreamKHostArgs& args,
std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
const ck_tile::stream_config& s)
{
using GemmShape = ck_tile::TileGemmShape<
@@ -28,7 +28,7 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::reboot::StreamKHostArgs&
GemmConfig::PermuteB>;
using TilePartitioner =
ck_tile::StreamKTilePartitioner_v2<GemmShape, ReductionStrategy, GemmConfig::Persistent>;
ck_tile::StreamKTilePartitioner<GemmShape, ReductionStrategy, GemmConfig::Persistent>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
@@ -77,7 +77,7 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::reboot::StreamKHostArgs&
memory_operation.value,
GemmConfig::NumWaveGroups>>;
using Kernel = ck_tile::reboot::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
using Kernel = ck_tile::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const auto workspace_size = Kernel::GetWorkSpaceSize(kargs);