mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-03-14 20:27:42 +00:00
[rocm-libraries] ROCm/rocm-libraries#4295 (commit fa2cfc8)
[CK_TILE] Refactor `UniversalGemm::MakeA/B/C/DBlockViews` to allow caller to pass desciptors directly (#4295) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Proposed changes Currently `UniversalGemmKernel::MakeA/B/C/DBlockViews` directly create tensor views from strides and sizes. This refactors the descriptor creation out and add overloaded definitions, allowing descriptors to be created separately by the caller instead of passing explicit strides, with no functional changes. This will enable further refactoring of `RunGemm` to do likewise, enabling derived kernels like BatchedContractionKernel to avoid creating separate versions (PR [#3457](https://github.com/ROCm/composable_kernel/pull/3457)). ## Checklist Please put an `x` into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask. - [x] I have added tests relevant to the introduced functionality, and the unit tests are passing locally - [ ] I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more than 30 seconds to run. - [ ] I have added inline documentation which enables the maintainers with understanding the motivation - [ ] I have removed the stale documentation which is no longer relevant after this pull request - [ ] (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request - [x] I have run `clang-format` on all changed files - [ ] Any dependent changes have been merged ## Discussion Since the logic within the MakeXBlockviews chains together operations on tuples, and thus the descriptors are also passed as such, adding a template parameter for the type of the input tuple was the simplest option to enable the overload without too much verbiage. However, for `MakeCBlockView` this adds a complications as the templated definitions are prone to overlap. This for now is avoided by just moving the arguments around for the descriptor version, which avoids the collision. It's not a great solution, so feel free to suggest a better one.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
4c626aeaa6
commit
cd12e8e31f
@@ -689,39 +689,150 @@ struct UniversalGemmKernel
|
||||
return AsTensorIsValid && BsTensorIsValid && DTensorIsValid;
|
||||
}
|
||||
|
||||
template <typename ALayout>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeABlockWindows(const std::array<const ADataType*, NumATensor>& as_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const index_t k_size,
|
||||
const index_t i_m)
|
||||
MakeDefaultATensorDescriptor(const index_t M, const index_t stride, const index_t k_size)
|
||||
{
|
||||
// Step 1: Create tensor views for A tensors (from MakeGemmTensorViews)
|
||||
const auto& as_tensor_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
|
||||
using AiDataType = remove_cvref_t<std::tuple_element_t<i.value, AsDataType>>;
|
||||
if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(M, k_size),
|
||||
make_tuple(stride, 1),
|
||||
number<GemmPipeline::GetVectorSizeA()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(k_size, M),
|
||||
make_tuple(stride, 1),
|
||||
number<GemmPipeline::GetVectorSizeA()>{},
|
||||
number<1>{});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename BLayout>
|
||||
CK_TILE_DEVICE static auto MakeDefaultBTensorDescriptor(const index_t N,
|
||||
const index_t K,
|
||||
const index_t stride,
|
||||
const index_t k_size)
|
||||
{
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
|
||||
{
|
||||
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
|
||||
const index_t K0 = k_size / K1;
|
||||
constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
|
||||
const auto b_k0_n_k1_desc = make_naive_tensor_descriptor(make_tuple(K0, N, K1),
|
||||
make_tuple(N * K1, K1, I1),
|
||||
number<VectorSizeB>{},
|
||||
number<1>{});
|
||||
return transform_tensor_descriptor(
|
||||
b_k0_n_k1_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(K0, K1)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(k_size, N),
|
||||
make_tuple(stride, 1),
|
||||
number<GemmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
|
||||
{
|
||||
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
|
||||
const index_t K0 = k_size / K1;
|
||||
constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
|
||||
const auto b_k0_n_k1_desc = make_naive_tensor_descriptor(make_tuple(K0, N, K1),
|
||||
make_tuple(N * K1, K1, I1),
|
||||
number<VectorSizeB>{},
|
||||
number<1>{});
|
||||
return transform_tensor_descriptor(
|
||||
b_k0_n_k1_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(K0, K1)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(GemmPipeline::Preshuffle)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const AiDataType*>(as_ptr[i]),
|
||||
make_tuple(kargs.M, k_size),
|
||||
make_tuple(kargs.stride_As[i], 1),
|
||||
number<GemmPipeline::GetVectorSizeA()>{},
|
||||
number<1>{});
|
||||
index_t kFlatK =
|
||||
GemmPipeline::BlockGemmShape::flatKPerWarp *
|
||||
(k_size / GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}));
|
||||
index_t kFlatN = N * K / kFlatK;
|
||||
|
||||
return make_naive_tensor_descriptor(make_tuple(kFlatN, kFlatK),
|
||||
make_tuple(kFlatK, 1),
|
||||
number<GemmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const AiDataType*>(as_ptr[i]),
|
||||
make_tuple(k_size, kargs.M),
|
||||
make_tuple(kargs.stride_As[i], 1),
|
||||
number<GemmPipeline::GetVectorSizeA()>{},
|
||||
number<1>{});
|
||||
return make_naive_tensor_descriptor(make_tuple(N, k_size),
|
||||
make_tuple(stride, 1),
|
||||
number<GemmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DLayout, index_t VectorSizeD>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeDefaultDTensorDescriptor(const index_t M, const index_t N, const index_t stride)
|
||||
{
|
||||
if constexpr(std::is_same_v<DLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(M, N), make_tuple(stride, 1), number<VectorSizeD>{}, number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N, M), make_tuple(stride, 1), number<VectorSizeD>{}, number<1>{});
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeDefaultETensorDescriptor(const index_t M, const index_t N, const index_t stride)
|
||||
{
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(M, N),
|
||||
make_tuple(stride, 1),
|
||||
number<EpiloguePipeline::GetVectorSizeC()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(M, N), make_tuple(1, stride), number<1>{}, number<1>{});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename AsTensorDesc>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeABlockWindows(const std::array<const ADataType*, NumATensor>& as_ptr,
|
||||
const AsTensorDesc& as_desc,
|
||||
const index_t i_m)
|
||||
{
|
||||
// Step 1: Create tensor views
|
||||
const auto& as_tensor_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
using AiDataType = remove_cvref_t<std::tuple_element_t<i.value, AsDataType>>;
|
||||
return make_tensor_view<address_space_enum::global>(
|
||||
static_cast<const AiDataType*>(as_ptr[i]), as_desc[i]);
|
||||
},
|
||||
number<NumATensor>{});
|
||||
|
||||
// Step 2: Create padded views (from MakeGemmPadViews)
|
||||
// Step 2: Create padded views
|
||||
const auto& as_pad_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
|
||||
@@ -742,7 +853,7 @@ struct UniversalGemmKernel
|
||||
},
|
||||
number<NumATensor>{});
|
||||
|
||||
// Step 3: Create tile windows (from MakeGemmTileWindows)
|
||||
// Step 3: Create tile windows
|
||||
const auto& as_block_window = generate_tuple(
|
||||
[&](auto i) {
|
||||
using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
|
||||
@@ -767,101 +878,38 @@ struct UniversalGemmKernel
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeBBlockWindows(const std::array<const BDataType*, NumBTensor>& bs_ptr,
|
||||
MakeABlockWindows(const std::array<const ADataType*, NumATensor>& as_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const index_t k_size,
|
||||
const index_t i_m)
|
||||
{
|
||||
// Step 1: Create tensor descriptors for A tensors
|
||||
const auto& as_tensor_desc = generate_tuple(
|
||||
[&](auto i) {
|
||||
using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
|
||||
return MakeDefaultATensorDescriptor<AiLayout>(kargs.M, kargs.stride_As[i], k_size);
|
||||
},
|
||||
number<NumATensor>{});
|
||||
|
||||
return MakeABlockWindows(as_ptr, as_tensor_desc, i_m);
|
||||
}
|
||||
|
||||
template <typename BsTensorDesc>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeBBlockWindows(const std::array<const BDataType*, NumBTensor>& bs_ptr,
|
||||
const BsTensorDesc& bs_desc,
|
||||
const index_t i_n)
|
||||
{
|
||||
// Step 1: Create tensor views for B tensors (from MakeGemmTensorViews)
|
||||
// Step 1: Create tensor views
|
||||
const auto& bs_tensor_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
|
||||
using BiDataType = remove_cvref_t<std::tuple_element_t<i.value, BsDataType>>;
|
||||
if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
|
||||
{
|
||||
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
|
||||
const index_t K0 = k_size / K1;
|
||||
constexpr index_t VectorSizeB =
|
||||
std::min(K1, GemmPipeline::GetVectorSizeB());
|
||||
const auto b_k0_n_k1_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
|
||||
make_tuple(kargs.N * K1, K1, I1),
|
||||
number<VectorSizeB>{},
|
||||
number<1>{});
|
||||
const auto b_n_k_desc = transform_tensor_descriptor(
|
||||
b_k0_n_k1_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(K0, K1)),
|
||||
make_pass_through_transform(kargs.N)),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return make_tensor_view<address_space_enum::global>(
|
||||
static_cast<const BiDataType*>(bs_ptr[i]), b_n_k_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
bs_ptr[i],
|
||||
make_tuple(k_size, kargs.N),
|
||||
make_tuple(kargs.stride_Bs[i], 1),
|
||||
number<GemmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
|
||||
{
|
||||
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
|
||||
const index_t K0 = k_size / K1;
|
||||
constexpr index_t VectorSizeB =
|
||||
std::min(K1, GemmPipeline::GetVectorSizeB());
|
||||
const auto b_k0_n_k1_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
|
||||
make_tuple(kargs.N * K1, K1, I1),
|
||||
number<VectorSizeB>{},
|
||||
number<1>{});
|
||||
const auto b_n_k_desc = transform_tensor_descriptor(
|
||||
b_k0_n_k1_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(K0, K1)),
|
||||
make_pass_through_transform(kargs.N)),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
return make_tensor_view<address_space_enum::global>(
|
||||
static_cast<const BiDataType*>(bs_ptr[i]), b_n_k_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(GemmPipeline::Preshuffle)
|
||||
{
|
||||
index_t kFlatK =
|
||||
GemmPipeline::BlockGemmShape::flatKPerWarp *
|
||||
(k_size / GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}));
|
||||
index_t kFlatN = kargs.N * kargs.K / kFlatK;
|
||||
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
bs_ptr[i],
|
||||
make_tuple(kFlatN, kFlatK),
|
||||
make_tuple(kFlatK, 1),
|
||||
number<GemmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
bs_ptr[i],
|
||||
make_tuple(kargs.N, k_size),
|
||||
make_tuple(kargs.stride_Bs[i], 1),
|
||||
number<GemmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
return make_tensor_view<address_space_enum::global>(
|
||||
static_cast<const BiDataType*>(bs_ptr[i]), bs_desc[i]);
|
||||
},
|
||||
number<NumBTensor>{});
|
||||
|
||||
// Step 2: Create padded views (from MakeGemmPadViews)
|
||||
// Step 2: Create padded views
|
||||
const auto& bs_pad_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
|
||||
@@ -882,7 +930,7 @@ struct UniversalGemmKernel
|
||||
},
|
||||
number<NumBTensor>{});
|
||||
|
||||
// Step 3: Create tile windows (from MakeGemmTileWindows)
|
||||
// Step 3: Create tile windows
|
||||
const auto& bs_block_window = generate_tuple(
|
||||
[&](auto i) {
|
||||
using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
|
||||
@@ -918,38 +966,39 @@ struct UniversalGemmKernel
|
||||
return bs_block_window;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeBBlockWindows(const std::array<const BDataType*, NumBTensor>& bs_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const index_t k_size,
|
||||
const index_t i_n)
|
||||
{
|
||||
const auto& bs_tensor_desc = generate_tuple(
|
||||
[&](auto i) {
|
||||
using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
|
||||
return MakeDefaultBTensorDescriptor<BiLayout>(
|
||||
kargs.N, kargs.K, kargs.stride_Bs[i], k_size);
|
||||
},
|
||||
number<NumBTensor>{});
|
||||
|
||||
return MakeBBlockWindows(bs_ptr, bs_tensor_desc, i_n);
|
||||
}
|
||||
|
||||
template <typename DsTensorDesc>
|
||||
CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const DsTensorDesc& ds_desc,
|
||||
const index_t i_m,
|
||||
const index_t i_n)
|
||||
{
|
||||
// Step 1: Create tensor views for D tensors (from MakeGemmTensorViews)
|
||||
// Step 1: Create tensor views
|
||||
const auto& ds_tensor_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
using DDataType_ = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const DDataType_*>(ds_ptr[i]),
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(kargs.stride_Ds[i], 1),
|
||||
number<EpiloguePipeline::GetVectorSizeD(i)>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const DDataType_*>(ds_ptr[i]),
|
||||
make_tuple(kargs.N, kargs.M),
|
||||
make_tuple(kargs.stride_Ds[i], 1),
|
||||
number<EpiloguePipeline::GetVectorSizeD(i)>{},
|
||||
number<1>{});
|
||||
}
|
||||
return make_tensor_view<address_space_enum::global>(
|
||||
static_cast<const DDataType_*>(ds_ptr[i]), ds_desc[i]);
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
// Step 2: Create padded views (from MakeGemmPadViews)
|
||||
// Step 2: Create padded views
|
||||
const auto& ds_pad_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
@@ -970,7 +1019,7 @@ struct UniversalGemmKernel
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
// Step 3: Create tile windows (from MakeGemmTileWindows)
|
||||
// Step 3: Create tile windows
|
||||
const auto& ds_block_window = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
@@ -994,35 +1043,34 @@ struct UniversalGemmKernel
|
||||
return ds_block_window;
|
||||
}
|
||||
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE static auto MakeCBlockWindows(EDataType* e_ptr,
|
||||
CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const index_t i_m,
|
||||
const index_t i_n)
|
||||
{
|
||||
// Step 1: Create tensor view for E/C tensor (from MakeGemmTensorViews)
|
||||
const auto& e_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
e_ptr,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(kargs.stride_E, 1),
|
||||
number<EpiloguePipeline::GetVectorSizeC()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
e_ptr,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(1, kargs.stride_E),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
const auto& ds_tensor_desc = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
return MakeDefaultDTensorDescriptor<DiLayout, EpiloguePipeline::GetVectorSizeD(i)>(
|
||||
kargs.M, kargs.N, kargs.stride_Ds[i]);
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
// Step 2: Create padded view (from MakeGemmPadViews)
|
||||
return MakeDBlockWindows(ds_ptr, ds_tensor_desc, i_m, i_n);
|
||||
}
|
||||
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set, typename ETensorDesc>
|
||||
CK_TILE_DEVICE static auto MakeCBlockWindows(
|
||||
EDataType* e_ptr,
|
||||
const index_t i_m,
|
||||
const index_t i_n,
|
||||
const ETensorDesc& e_desc) // Argument order differs from A,B,D to disambiguate overloads
|
||||
{
|
||||
// Step 1: Create tensor view for E/C tensor
|
||||
const auto& e_tensor_view =
|
||||
make_tensor_view<address_space_enum::global, DstInMemOp>(e_ptr, e_desc);
|
||||
|
||||
// Step 2: Create padded view
|
||||
const auto& e_pad_view = [&]() {
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
@@ -1040,7 +1088,7 @@ struct UniversalGemmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
// Step 3: Create tile window (from MakeGemmTileWindows)
|
||||
// Step 3: Create tile window
|
||||
auto e_block_window = make_tile_window(
|
||||
e_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
@@ -1049,6 +1097,17 @@ struct UniversalGemmKernel
|
||||
return e_block_window;
|
||||
}
|
||||
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE static auto MakeCBlockWindows(EDataType* e_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const index_t i_m,
|
||||
const index_t i_n)
|
||||
{
|
||||
|
||||
const auto& e_tensor_desc = MakeDefaultETensorDescriptor(kargs.M, kargs.N, kargs.stride_E);
|
||||
return MakeCBlockWindows<DstInMemOp>(e_ptr, i_m, i_n, e_tensor_desc);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Runs single GEMM problem cooperatively by whole workgroup.
|
||||
*
|
||||
|
||||
Reference in New Issue
Block a user