Merge commit 'e339101e9c9961fe1bc8305d5c316b39d1980d3e' into develop

This commit is contained in:
assistant-librarian[bot]
2026-01-04 12:20:15 +00:00
parent d8734393c7
commit c5abac7854
68 changed files with 4198 additions and 4298 deletions

View File

@@ -558,21 +558,19 @@ struct FlatmmKernel
return DTesnorIsValid;
}
template <memory_operation_enum DstInMemOp = memory_operation_enum::set, class KernelArgs>
CK_TILE_DEVICE static auto
MakeGemmTensorViews(const ADataType* a_ptr,
const BDataType* b_flat_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
EDataType* e_ptr,
const KernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset)
template <typename KernelArgs>
CK_TILE_DEVICE static auto MakeABlockWindow(const ADataType* a_ptr,
const KernelArgs& kargs,
const index_t k_size,
const index_t block_idx_m)
{
// Step 1: Create tensor view
const auto& a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(kargs.M, splitk_batch_offset.splitted_k),
make_tuple(kargs.M, k_size),
make_tuple(kargs.stride_A, 1),
number<FlatmmPipeline::GetVectorSizeA()>{},
number<1>{});
@@ -581,25 +579,81 @@ struct FlatmmKernel
{
return make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(splitk_batch_offset.splitted_k, kargs.M),
make_tuple(k_size, kargs.M),
make_tuple(kargs.stride_A, 1),
number<FlatmmPipeline::GetVectorSizeA()>{},
number<1>{});
}
}();
index_t kFlatK =
FlatmmPipeline::flatKPerWarp * (kargs.K / BlockGemmShape::WarpTile::at(I2));
index_t kFlatN = kargs.N * kargs.K / kFlatK;
const auto& b_flat_tensor_view = [&]() {
return make_naive_tensor_view<address_space_enum::global>(
b_flat_ptr,
make_tuple(kFlatN, kFlatK),
make_tuple(kFlatK, 1),
number<FlatmmPipeline::GetVectorSizeB()>{},
number<1>{});
// Step 2: Create padded view
const auto& a_pad_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, FlatmmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<false, FlatmmPipeline::kPadM>{});
}
}();
// Step 3: Create tile window
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{block_idx_m, 0});
}
else
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
{0, block_idx_m});
}
}
template <typename KernelArgs>
CK_TILE_DEVICE static auto MakeBFlatBlockWindow(const BDataType* b_flat_ptr,
const KernelArgs& kargs,
const index_t block_idx_n)
{
// Step 1: Create tensor view
index_t kFlatK =
FlatmmPipeline::flatKPerWarp * (kargs.K / BlockGemmShape::WarpTile::at(I2));
index_t kFlatN = kargs.N * kargs.K / kFlatK;
const auto& b_flat_tensor_view = make_naive_tensor_view<address_space_enum::global>(
b_flat_ptr,
make_tuple(kFlatN, kFlatK),
make_tuple(kFlatK, 1),
number<FlatmmPipeline::GetVectorSizeB()>{},
number<1>{});
// Step 2: No padding needed for b_flat
// Step 3: Create tile window
return make_tile_window(
b_flat_tensor_view,
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
number<FlatmmPipeline::flatKPerWarp>{}),
{static_cast<int>(block_idx_n / BlockGemmShape::WarpTile::at(I1)), 0});
}
template <typename KernelArgs>
CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array<const void*, NumDTensor>& ds_ptr,
const KernelArgs& kargs,
const index_t block_idx_m,
const index_t block_idx_n)
{
// Step 1: Create tensor views
const auto& ds_tensor_view = generate_tuple(
[&](auto i) {
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
@@ -625,7 +679,56 @@ struct FlatmmKernel
},
number<NumDTensor>{});
// TODO: enable vector write for C in ColMajor
// Step 2: Create padded views
const auto& ds_pad_view = generate_tuple(
[&](auto i) {
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(ds_tensor_view[i],
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, FlatmmPipeline::kPadN>{});
}
else
{
return pad_tensor_view(ds_tensor_view[i],
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<false, FlatmmPipeline::kPadM>{});
}
},
number<NumDTensor>{});
// Step 3: Create tile windows
return generate_tuple(
[&](auto i) {
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(ds_pad_view[i],
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{block_idx_m, block_idx_n});
}
else
{
return make_tile_window(ds_pad_view[i],
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
{block_idx_n, block_idx_m});
}
},
number<NumDTensor>{});
}
template <memory_operation_enum DstInMemOp = memory_operation_enum::set, typename KernelArgs>
CK_TILE_DEVICE static auto MakeEBlockWindow(EDataType* e_ptr,
const KernelArgs& kargs,
const index_t block_idx_m,
const index_t block_idx_n)
{
// Step 1: Create tensor view
const auto& e_tensor_view = [&]() {
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
@@ -647,98 +750,8 @@ struct FlatmmKernel
}
}();
constexpr int ScaleGranularityM = decltype(kargs.scale_m_ptr)::GranularityMN;
constexpr int ScaleGranularityN = decltype(kargs.scale_n_ptr)::GranularityMN;
constexpr int ScaleGranularityKA = decltype(kargs.scale_m_ptr)::GranularityK;
constexpr int ScaleGranularityKB = decltype(kargs.scale_n_ptr)::GranularityK;
auto scale_stride_m = ScaleGranularityM == 0 ? 0 // per-tensor scale
: 1; // per-token scale
auto scale_stride_n = ScaleGranularityN == 0 ? 0 // per-tensor scale
: 1; // per-channel scale
static_assert(ScaleGranularityM == 0 || ScaleGranularityM == 1 || ScaleGranularityM == -1,
"only support per-tensor or per-row scaling");
static_assert(ScaleGranularityN == 0 || ScaleGranularityN == 1 || ScaleGranularityN == -1,
"only support per-tensor or per-column scaling");
const auto scale_m_view = make_naive_tensor_view<address_space_enum::global>(
kargs.scale_m_ptr.ptr,
make_tuple(kargs.M / ScaleGranularityM,
ScaleGranularityKA == 0
? 1
: splitk_batch_offset.splitted_k /
(ScaleGranularityKA != 0 ? ScaleGranularityKA : 1)),
make_tuple(scale_stride_m, 0),
number < ScaleGranularityM == 1 ? FlatmmPipeline::GetVectorSizeA() : 1 > {},
number<1>{});
const auto scale_n_view = make_naive_tensor_view<address_space_enum::global>(
kargs.scale_n_ptr.ptr,
make_tuple(ScaleGranularityKB == 0
? 1
: (splitk_batch_offset.splitted_k /
(ScaleGranularityKB != 0 ? ScaleGranularityKB : 1)),
kargs.N / ScaleGranularityN),
make_tuple(0, scale_stride_n),
number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {},
number<1>{});
return make_tuple(a_tensor_view,
b_flat_tensor_view,
ds_tensor_view,
e_tensor_view,
scale_m_view,
scale_n_view);
}
template <typename TensorView>
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
{
const auto& a_pad_view = [&]() {
const auto& a_tensor_view = views.at(I0);
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, FlatmmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<false, FlatmmPipeline::kPadM>{});
}
}();
const auto& b_flat_tensor_view = views.at(I1);
const auto& ds_pad_view = generate_tuple(
[&](auto i) {
const auto& d_tensor_view = views.at(I2);
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(d_tensor_view[i],
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, FlatmmPipeline::kPadN>{});
}
else
{
return pad_tensor_view(d_tensor_view[i],
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<false, FlatmmPipeline::kPadM>{});
}
},
number<NumDTensor>{});
// TODO vector write in for C in ColMajor
// Step 2: Create padded view
const auto& e_pad_view = [&]() {
const auto& e_tensor_view = views.at(I3);
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(e_tensor_view,
@@ -755,93 +768,72 @@ struct FlatmmKernel
}
}();
return make_tuple(a_pad_view,
b_flat_tensor_view,
ds_pad_view,
e_pad_view,
views.at(number<4>{}),
views.at(number<5>{}));
}
template <typename PadView>
CK_TILE_DEVICE static auto
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
{
const auto& a_pad_view = views.at(I0);
const auto& b_flat_pad_view = views.at(I1);
const auto& ds_pad_view = views.at(I2);
const auto& e_pad_view = views.at(I3);
const auto& a_block_window = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_m, 0});
}
else
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
{0, i_m});
}
}();
const auto& b_flat_block_window =
make_tile_window(b_flat_pad_view,
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
number<FlatmmPipeline::flatKPerWarp>{}),
{static_cast<int>(i_n / BlockGemmShape::WarpTile::at(I1)), 0});
const auto ds_block_window = generate_tuple(
[&](auto i) {
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(ds_pad_view[i],
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
}
else
{
return make_tile_window(ds_pad_view[i],
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
{i_n, i_m});
}
},
number<NumDTensor>{});
auto e_block_window = make_tile_window(
// Step 3: Create tile window
return make_tile_window(
e_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
{block_idx_m, block_idx_n});
}
constexpr int ScaleGranularityKA = 0; // decltype(kargs.scale_m_ptr)::GranularityK;
constexpr int ScaleGranularityKB = 0; // decltype(kargs.scale_n_ptr)::GranularityK;
template <typename KernelArgs>
CK_TILE_DEVICE static auto MakeScaleMWindow(const KernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m)
{
constexpr int ScaleGranularityM = decltype(kargs.scale_m_ptr)::GranularityMN;
constexpr int ScaleGranularityKA = decltype(kargs.scale_m_ptr)::GranularityK;
auto scale_m_window = make_tile_window(views.at(number<4>{}),
make_tuple(number<TilePartitioner::MPerBlock>{},
number < ScaleGranularityKA == 0
? TilePartitioner::NPerBlock
: TilePartitioner::KPerBlock > {}),
{i_m, 0});
auto scale_n_window = make_tile_window(views.at(number<5>{}),
make_tuple(number < ScaleGranularityKB == 0
? TilePartitioner::MPerBlock
: TilePartitioner::KPerBlock > {},
number<TilePartitioner::NPerBlock>{}),
{0, i_n});
auto scale_stride_m = ScaleGranularityM == 0 ? 0 // per-tensor scale
: 1; // per-token scale
return make_tuple(a_block_window,
b_flat_block_window,
ds_block_window,
e_block_window,
scale_m_window,
scale_n_window);
// Step 1: Create tensor view
const auto scale_m_view = make_naive_tensor_view<address_space_enum::global>(
kargs.scale_m_ptr.ptr,
make_tuple(kargs.M / ScaleGranularityM,
ScaleGranularityKA == 0
? 1
: (splitk_batch_offset.splitted_k / ScaleGranularityKA)),
make_tuple(scale_stride_m, 0),
number < ScaleGranularityM == 1 ? FlatmmPipeline::GetVectorSizeA() : 1 > {},
number<1>{});
// Step 2: Create tile window
return make_tile_window(scale_m_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number < ScaleGranularityKA == 0
? TilePartitioner::NPerBlock
: TilePartitioner::KPerBlock > {}),
{block_idx_m, 0});
}
template <typename KernelArgs>
CK_TILE_DEVICE static auto MakeScaleNWindow(const KernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_n)
{
constexpr int ScaleGranularityN = decltype(kargs.scale_n_ptr)::GranularityMN;
constexpr int ScaleGranularityKB = decltype(kargs.scale_n_ptr)::GranularityK;
auto scale_stride_n = ScaleGranularityN == 0 ? 0 // per-tensor scale
: 1; // per-channel scale
// Step 1: Create tensor view
const auto scale_n_view = make_naive_tensor_view<address_space_enum::global>(
kargs.scale_n_ptr.ptr,
make_tuple(
ScaleGranularityKB == 0 ? 1 : (splitk_batch_offset.splitted_k / ScaleGranularityKB),
kargs.N / ScaleGranularityN),
make_tuple(0, scale_stride_n),
number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {},
number<1>{});
// Step 2: Create tile window
return make_tile_window(scale_n_view,
make_tuple(number < ScaleGranularityKB == 0
? TilePartitioner::MPerBlock
: TilePartitioner::KPerBlock > {},
number<TilePartitioner::NPerBlock>{}),
{0, block_idx_n});
}
template <class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
@@ -857,45 +849,74 @@ struct FlatmmKernel
const index_t block_idx_m,
const index_t block_idx_n)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
a_ptr, b_flat_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
// Create block windows using specialized methods
const auto& a_block_window =
MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
const auto& b_flat_block_window = MakeBFlatBlockWindow(b_flat_ptr, kargs, block_idx_n);
const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
const auto& scale_m_window = MakeScaleMWindow(kargs, splitk_batch_offset, block_idx_m);
const auto& scale_n_window = MakeScaleNWindow(kargs, splitk_batch_offset, block_idx_n);
const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
// Run GEMM cooperatively by whole workgroup.
const auto& a_block_window = gemm_tile_windows.at(I0);
const auto& b_flat_block_window = gemm_tile_windows.at(I1);
const auto& d_block_window = gemm_tile_windows.at(I2);
const auto& c_block_tile = FlatmmPipeline{}.template operator()(
const auto& c_block_tile = FlatmmPipeline{}.template operator()(
a_block_window, b_flat_block_window, num_loop, smem_ptr_ping, smem_ptr_pong);
auto scale_m_window = gemm_tile_windows.at(number<4>{});
auto scale_n_window = gemm_tile_windows.at(number<5>{});
// Run Epilogue Pipeline
// Run Epilogue Pipeline with k_batch dispatching
if constexpr(ScaleM::GranularityMN != -1 || ScaleN::GranularityMN != -1)
{
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}.template
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
c_block_window,
c_block_tile,
d_block_window,
smem_ptr_ping,
scale_m_window,
scale_n_window);
if(kargs.k_batch == 1)
{
auto e_block_window = MakeEBlockWindow<memory_operation_enum::set>(
e_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}
.template operator()<decltype(e_block_window),
decltype(c_block_tile),
decltype(ds_block_window)>(e_block_window,
c_block_tile,
ds_block_window,
smem_ptr_ping,
scale_m_window,
scale_n_window);
}
else
{
auto e_block_window = MakeEBlockWindow<memory_operation_enum::atomic_add>(
e_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}
.template operator()<decltype(e_block_window),
decltype(c_block_tile),
decltype(ds_block_window)>(e_block_window,
c_block_tile,
ds_block_window,
smem_ptr_ping,
scale_m_window,
scale_n_window);
}
}
else if(UseDefaultScheduler || (get_warp_id() == 0))
{
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}.template
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
if(kargs.k_batch == 1)
{
auto e_block_window = MakeEBlockWindow<memory_operation_enum::set>(
e_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}
.template operator()<decltype(e_block_window),
decltype(c_block_tile),
decltype(ds_block_window)>(
e_block_window, c_block_tile, ds_block_window, smem_ptr_ping);
}
else
{
auto e_block_window = MakeEBlockWindow<memory_operation_enum::atomic_add>(
e_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}
.template operator()<decltype(e_block_window),
decltype(c_block_tile),
decltype(ds_block_window)>(
e_block_window, c_block_tile, ds_block_window, smem_ptr_ping);
}
}
}
@@ -924,8 +945,7 @@ struct FlatmmKernel
__shared__ char smem_ptr_ping[GetSmemPingSize()];
__shared__ char smem_ptr_pong[GetSmemPongSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<EDataType, fp16_t, bf16_t>::value))
{
constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);