WMMA grouped conv fwd large tensor extra flavors (#3582)

* Additional flavors for WMMA conv fwd large tensor

- added F16/BF16 clamp operation
- added F16/BF16 bias_clamp operation
- small modification to the device code to accomodate extra tensors

* changed strategy to handle GemmArgs array

* Adding generic instance

* Added generic instance to clamp and bias_clamp ops
This commit is contained in:
Wojciech Laskowski
2026-01-23 12:19:51 +01:00
committed by GitHub
parent 7b3db1a878
commit 81ee19bd2c
27 changed files with 1007 additions and 171 deletions

View File

@@ -617,32 +617,32 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor
const auto m_block = GridwiseGemm::CalculateMBlock(gemm_m);
const auto n_block = GridwiseGemm::CalculateNBlock(gemm_n);
GemmArgs new_args{};
new_args.a_ptrs_ = p_as_grid;
new_args.b_ptrs_ = p_bs_grid;
new_args.ds_ptrs_ = p_ds_grid;
new_args.e_ptr_ = p_e_grid;
new_args.a_element_op_ = a_element_op_;
new_args.b_element_op_ = b_element_op_;
new_args.cde_element_op_ = cde_element_op_;
new_args.M_ = gemm_m;
new_args.N_ = gemm_n;
new_args.a_grid_desc_ = a_grid_desc;
new_args.b_grid_desc_ = b_grid_desc;
new_args.ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
const auto ds_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n, m_block, n_block);
new_args.e_grid_desc_mblock_mperblock_nblock_nperblock_ =
const auto e_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n, m_block, n_block);
new_args.BlockStart_ = BlockStart;
new_args.BlockEnd_ = BlockEnd;
gemm_desc_kernel_args_.At(valid_gemms_count_) = new_args;
gemm_desc_kernel_args_.Emplace(
valid_gemms_count_,
GemmArgs{.a_ptrs_ = p_as_grid,
.b_ptrs_ = p_bs_grid,
.ds_ptrs_ = p_ds_grid,
.e_ptr_ = p_e_grid,
.a_element_op_ = a_element_op_,
.b_element_op_ = b_element_op_,
.cde_element_op_ = cde_element_op_,
.M_ = gemm_m,
.N_ = gemm_n,
.a_grid_desc_ = a_grid_desc,
.b_grid_desc_ = b_grid_desc,
.ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
ds_desc_mblock_mperblock_nblock_nperblock,
.e_grid_desc_mblock_mperblock_nblock_nperblock_ =
e_desc_mblock_mperblock_nblock_nperblock,
.BlockStart_ = BlockStart,
.BlockEnd_ = BlockEnd});
valid_gemms_count_++;
}
@@ -789,11 +789,14 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor
compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides_[1] * conv_N_per_block_;
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides_[1] * conv_N_per_block_;
static_for<0, NumDTensor, 1>{}([&](auto i) {
compute_ptr_offset_of_groups_.BatchStrideDs_(i) = ds_g_n_k_wos_strides_[i][0];
compute_ptr_offset_of_n_.BatchStrideDs_(i) =
ds_g_n_k_wos_strides_[i][1] * conv_N_per_block_;
});
if constexpr(NumDTensor > 0)
{
static_for<0, NumDTensor, 1>{}([&](auto i) {
compute_ptr_offset_of_groups_.BatchStrideDs_(i) = ds_g_n_k_wos_strides_[i][0];
compute_ptr_offset_of_n_.BatchStrideDs_(i) =
ds_g_n_k_wos_strides_[i][1] * conv_N_per_block_;
});
}
}
void Print() const
@@ -807,12 +810,15 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor
<< ", is_split_valid=" << std::boolalpha << is_split_valid_
<< std::noboolalpha << ", grid_size=" << grid_size_ << std::endl;
static_for<0, NumDTensor, 1>{}([&](auto i) {
std::cout << " Ds[" << i.value
<< "] group stride=" << compute_ptr_offset_of_groups_.BatchStrideDs_(i)
<< ", n stride=" << compute_ptr_offset_of_n_.BatchStrideDs_(i)
<< std::endl;
});
if constexpr(NumDTensor > 0)
{
static_for<0, NumDTensor, 1>{}([&](auto i) {
std::cout << " Ds[" << i.value << "] group stride="
<< compute_ptr_offset_of_groups_.BatchStrideDs_.At(i)
<< ", n stride=" << compute_ptr_offset_of_n_.BatchStrideDs_.At(i)
<< std::endl;
});
}
std::cout << "===== GEMM splits =====" << std::endl;
for(index_t i = 0; i < valid_gemms_count_; ++i)
@@ -836,11 +842,14 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor
std::cout << " E[MBlock, MPerBlock, NBlock, NPerBlock]: "
<< gemm.e_grid_desc_mblock_mperblock_nblock_nperblock_ << std::endl;
static_for<0, NumDTensor, 1>{}([&](auto d_idx) {
std::cout << " D" << d_idx.value << " descriptor: "
<< gemm.ds_grid_desc_mblock_mperblock_nblock_nperblock_(d_idx)
<< std::endl;
});
if constexpr(NumDTensor > 0)
{
static_for<0, NumDTensor, 1>{}([&](auto d_idx) {
std::cout << " D" << d_idx.value << " descriptor: "
<< gemm.ds_grid_desc_mblock_mperblock_nblock_nperblock_.At(d_idx)
<< std::endl;
});
}
}
}

View File

@@ -6,6 +6,8 @@
#include "functional2.hpp"
#include "sequence.hpp"
#include <type_traits>
#include <cassert>
namespace ck {
@@ -27,6 +29,15 @@ struct Array
__host__ __device__ constexpr TData& operator()(index_t i) { return At(i); }
template <typename... Args>
__host__ constexpr auto Emplace(index_t i, Args&&... args)
-> std::enable_if_t<std::is_nothrow_constructible_v<TData, Args&&...>>
{
assert(i >= 0 && i < NSize);
mData[i].~TData();
new(mData + i) TData(ck::forward<Args>(args)...);
}
template <typename T>
__host__ __device__ constexpr auto operator=(const T& a)
{