Mirror of https://github.com/ROCm/composable_kernel.git
WMMA grouped conv fwd large tensor extra flavors (#3582)
* Additional flavors for WMMA conv fwd large tensor: added F16/BF16 clamp operation, added F16/BF16 bias_clamp operation, and a small modification to the device code to accommodate the extra tensors
* Changed strategy for handling the GemmArgs array
* Added generic instance
* Added generic instance to clamp and bias_clamp ops
Committed by: GitHub
Parent commit: 7b3db1a878
This commit: 81ee19bd2c
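The clamp and bias_clamp flavors mentioned in the commit message are element-wise epilogue (CDE) operations applied when the GEMM result is written back in F16/BF16. The actual operation types are defined elsewhere in CK and do not appear in this diff; the snippet below is only a hedged sketch of what such a clamp functor can look like, and the struct name, members, and signature are assumptions, not the code added by this commit.

    // Illustrative sketch only: a clamp-style CDE epilogue functor in the spirit of
    // CK element-wise ops. Name, members, and signature are assumptions; the real
    // clamp / bias_clamp operations live elsewhere in CK.
    #include <algorithm>
    #include <cstdio>

    struct ClampSketch
    {
        float floor_ = 0.f;
        float ceil_  = 1.f;

        // e: destination element (e.g. F16/BF16), c: GEMM accumulator value
        template <typename E, typename C>
        void operator()(E& e, const C& c) const
        {
            e = static_cast<E>(std::min(std::max(static_cast<float>(c), floor_), ceil_));
        }
    };

    // A bias_clamp flavor would take one extra input, the bias element from a D
    // tensor, and clamp (c + bias) instead of c before the final cast.

    int main()
    {
        ClampSketch clamp{0.f, 6.f};
        float out = 0.f;
        clamp(out, 7.25f);        // clamps into the [0, 6] range
        std::printf("%f\n", out); // prints 6.000000
    }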
@@ -617,32 +617,32 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor
                 const auto m_block = GridwiseGemm::CalculateMBlock(gemm_m);
                 const auto n_block = GridwiseGemm::CalculateNBlock(gemm_n);
 
-                GemmArgs new_args{};
-                new_args.a_ptrs_ = p_as_grid;
-                new_args.b_ptrs_ = p_bs_grid;
-                new_args.ds_ptrs_ = p_ds_grid;
-                new_args.e_ptr_ = p_e_grid;
-
-                new_args.a_element_op_ = a_element_op_;
-                new_args.b_element_op_ = b_element_op_;
-                new_args.cde_element_op_ = cde_element_op_;
-
-                new_args.M_ = gemm_m;
-                new_args.N_ = gemm_n;
-
-                new_args.a_grid_desc_ = a_grid_desc;
-                new_args.b_grid_desc_ = b_grid_desc;
-                new_args.ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
-                    GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
-                        ds_grid_desc_m_n, m_block, n_block);
-                new_args.e_grid_desc_mblock_mperblock_nblock_nperblock_ =
-                    GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
-                        e_grid_desc_m_n, m_block, n_block);
-
-                new_args.BlockStart_ = BlockStart;
-                new_args.BlockEnd_ = BlockEnd;
-
-                gemm_desc_kernel_args_.At(valid_gemms_count_) = new_args;
+                const auto ds_desc_mblock_mperblock_nblock_nperblock =
+                    GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                        ds_grid_desc_m_n, m_block, n_block);
+                const auto e_desc_mblock_mperblock_nblock_nperblock =
+                    GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                        e_grid_desc_m_n, m_block, n_block);
+
+                gemm_desc_kernel_args_.Emplace(
+                    valid_gemms_count_,
+                    GemmArgs{.a_ptrs_ = p_as_grid,
+                             .b_ptrs_ = p_bs_grid,
+                             .ds_ptrs_ = p_ds_grid,
+                             .e_ptr_ = p_e_grid,
+                             .a_element_op_ = a_element_op_,
+                             .b_element_op_ = b_element_op_,
+                             .cde_element_op_ = cde_element_op_,
+                             .M_ = gemm_m,
+                             .N_ = gemm_n,
+                             .a_grid_desc_ = a_grid_desc,
+                             .b_grid_desc_ = b_grid_desc,
+                             .ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
+                                 ds_desc_mblock_mperblock_nblock_nperblock,
+                             .e_grid_desc_mblock_mperblock_nblock_nperblock_ =
+                                 e_desc_mblock_mperblock_nblock_nperblock,
+                             .BlockStart_ = BlockStart,
+                             .BlockEnd_ = BlockEnd});
 
                 valid_gemms_count_++;
             }
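The hunk above drops the old pattern of default-constructing a GemmArgs, filling it member by member, and copy-assigning it into gemm_desc_kernel_args_.At(valid_gemms_count_). Instead, the element is built once with designated initializers and handed to the new Array::Emplace (added at the bottom of this commit), which constructs it directly in the slot, so no intermediate object and no copy assignment are involved. A minimal, self-contained sketch of the two styles, using a made-up Args type:

    // Minimal sketch of the two construction styles (Args is a stand-in for GemmArgs).
    #include <new>

    struct Args
    {
        int M = 0;
        int N = 0;
    };

    // Old style: default-construct, fill member by member, then copy-assign into the slot.
    void fill_by_assignment(Args& slot)
    {
        Args a{};
        a.M  = 128;
        a.N  = 256;
        slot = a;   // relies on Args being copy-assignable
    }

    // New style: build the element with designated initializers and construct it
    // directly in the slot, which is what Array::Emplace does internally.
    void fill_in_place(Args* slot)
    {
        slot->~Args();
        new(slot) Args{.M = 128, .N = 256};
    }

    int main()
    {
        Args slots[2]{};
        fill_by_assignment(slots[0]);
        fill_in_place(&slots[1]);
    }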
@@ -789,11 +789,14 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor
             compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides_[1] * conv_N_per_block_;
             compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides_[1] * conv_N_per_block_;
 
-            static_for<0, NumDTensor, 1>{}([&](auto i) {
-                compute_ptr_offset_of_groups_.BatchStrideDs_(i) = ds_g_n_k_wos_strides_[i][0];
-                compute_ptr_offset_of_n_.BatchStrideDs_(i) =
-                    ds_g_n_k_wos_strides_[i][1] * conv_N_per_block_;
-            });
+            if constexpr(NumDTensor > 0)
+            {
+                static_for<0, NumDTensor, 1>{}([&](auto i) {
+                    compute_ptr_offset_of_groups_.BatchStrideDs_(i) = ds_g_n_k_wos_strides_[i][0];
+                    compute_ptr_offset_of_n_.BatchStrideDs_(i) =
+                        ds_g_n_k_wos_strides_[i][1] * conv_N_per_block_;
+                });
+            }
         }
 
         void Print() const
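Here and in the two Print() hunks below, the per-D-tensor static_for loop is now wrapped in if constexpr(NumDTensor > 0), so the whole block is discarded at compile time when the instance has no D tensors. A minimal, self-contained sketch of the guard pattern; NumD and the loop body are illustrative, not the CK code:

    // Minimal sketch of the if constexpr guard over a compile-time tensor count.
    #include <array>
    #include <cstdio>

    template <int NumD>
    void set_d_strides(std::array<int, NumD>& strides)
    {
        // Skip the whole per-D-tensor block at compile time when there are no D tensors.
        if constexpr(NumD > 0)
        {
            for(int i = 0; i < NumD; ++i)   // stands in for ck::static_for
            {
                strides[i] = 42;
            }
        }
    }

    int main()
    {
        std::array<int, 0> none{};
        std::array<int, 3> three{};
        set_d_strides<0>(none);    // guarded branch discarded, empty array untouched
        set_d_strides<3>(three);
        std::printf("%d\n", three[2]);   // prints 42
    }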
@@ -807,12 +810,15 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor
                       << ", is_split_valid=" << std::boolalpha << is_split_valid_
                       << std::noboolalpha << ", grid_size=" << grid_size_ << std::endl;
 
-            static_for<0, NumDTensor, 1>{}([&](auto i) {
-                std::cout << " Ds[" << i.value
-                          << "] group stride=" << compute_ptr_offset_of_groups_.BatchStrideDs_(i)
-                          << ", n stride=" << compute_ptr_offset_of_n_.BatchStrideDs_(i)
-                          << std::endl;
-            });
+            if constexpr(NumDTensor > 0)
+            {
+                static_for<0, NumDTensor, 1>{}([&](auto i) {
+                    std::cout << " Ds[" << i.value << "] group stride="
+                              << compute_ptr_offset_of_groups_.BatchStrideDs_.At(i)
+                              << ", n stride=" << compute_ptr_offset_of_n_.BatchStrideDs_.At(i)
+                              << std::endl;
+                });
+            }
 
             std::cout << "===== GEMM splits =====" << std::endl;
             for(index_t i = 0; i < valid_gemms_count_; ++i)
@@ -836,11 +842,14 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor
                 std::cout << " E[MBlock, MPerBlock, NBlock, NPerBlock]: "
                           << gemm.e_grid_desc_mblock_mperblock_nblock_nperblock_ << std::endl;
 
-                static_for<0, NumDTensor, 1>{}([&](auto d_idx) {
-                    std::cout << " D" << d_idx.value << " descriptor: "
-                              << gemm.ds_grid_desc_mblock_mperblock_nblock_nperblock_(d_idx)
-                              << std::endl;
-                });
+                if constexpr(NumDTensor > 0)
+                {
+                    static_for<0, NumDTensor, 1>{}([&](auto d_idx) {
+                        std::cout << " D" << d_idx.value << " descriptor: "
+                                  << gemm.ds_grid_desc_mblock_mperblock_nblock_nperblock_.At(d_idx)
+                                  << std::endl;
+                    });
+                }
             }
         }
@@ -6,6 +6,8 @@
 
 #include "functional2.hpp"
 #include "sequence.hpp"
+#include <type_traits>
+#include <cassert>
 
 namespace ck {
 
@@ -27,6 +29,15 @@ struct Array
 
     __host__ __device__ constexpr TData& operator()(index_t i) { return At(i); }
 
+    template <typename... Args>
+    __host__ constexpr auto Emplace(index_t i, Args&&... args)
+        -> std::enable_if_t<std::is_nothrow_constructible_v<TData, Args&&...>>
+    {
+        assert(i >= 0 && i < NSize);
+        mData[i].~TData();
+        new(mData + i) TData(ck::forward<Args>(args)...);
+    }
+
     template <typename T>
     __host__ __device__ constexpr auto operator=(const T& a)
     {
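The trailing return type makes the new Emplace SFINAE-constrained: it only participates in overload resolution when TData is nothrow-constructible from the supplied arguments; the body then destroys whatever element currently sits in slot i and placement-news the replacement. Below is a hedged usage sketch; the Widget type, the calling code, and the include path are assumptions made for illustration, not part of the commit.

    // Hypothetical host-side usage of the new ck::Array::Emplace.
    #include "ck/utility/array.hpp"   // assumed location of ck::Array; path may differ

    struct Widget
    {
        int   id;
        float weight;

        constexpr Widget() noexcept : id(-1), weight(0.f) {}
        constexpr Widget(int id_, float weight_) noexcept : id(id_), weight(weight_) {}
    };

    int main()
    {
        ck::Array<Widget, 4> widgets{};

        // Enabled via the enable_if_t constraint because Widget is nothrow-constructible
        // from (int, float); the old element in slot 2 is destroyed and Widget(7, 0.5f)
        // is constructed in its place.
        widgets.Emplace(2, 7, 0.5f);

        return widgets.At(2).id;   // 7
    }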