mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 13:17:36 +00:00
Separate tensor descriptor creation from the tensor view creation
This adds utility functions to construct default tensor descriptors for
A, B, C and D tensors and refactors the Make{A,B,C,D}BlockWindows to call make_tensor_view
using the utility functions instead of directly calling
make_naive_tensor_view, allowing for further refactors later.
This commit is contained in:
@@ -661,6 +661,136 @@ struct UniversalGemmKernel
|
||||
return AsTensorIsValid && BsTensorIsValid && DTensorIsValid;
|
||||
}
|
||||
|
||||
template <typename ALayout>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeDefaultATensorDescriptor(const index_t M, const index_t stride, const index_t k_size)
|
||||
{
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(M, k_size),
|
||||
make_tuple(stride, 1),
|
||||
number<GemmPipeline::GetVectorSizeA()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(k_size, M),
|
||||
make_tuple(stride, 1),
|
||||
number<GemmPipeline::GetVectorSizeA()>{},
|
||||
number<1>{});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename BLayout>
|
||||
CK_TILE_DEVICE static auto MakeDefaultBTensorDescriptor(const index_t N,
|
||||
const index_t K,
|
||||
const index_t stride,
|
||||
const index_t k_size)
|
||||
{
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
|
||||
{
|
||||
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
|
||||
const index_t K0 = k_size / K1;
|
||||
constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
|
||||
const auto b_k0_n_k1_desc = make_naive_tensor_descriptor(make_tuple(K0, N, K1),
|
||||
make_tuple(N * K1, K1, I1),
|
||||
number<VectorSizeB>{},
|
||||
number<1>{});
|
||||
return transform_tensor_descriptor(
|
||||
b_k0_n_k1_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(K0, K1)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(k_size, N),
|
||||
make_tuple(stride, 1),
|
||||
number<GemmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
|
||||
{
|
||||
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
|
||||
const index_t K0 = k_size / K1;
|
||||
constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
|
||||
const auto b_k0_n_k1_desc = make_naive_tensor_descriptor(make_tuple(K0, N, K1),
|
||||
make_tuple(N * K1, K1, I1),
|
||||
number<VectorSizeB>{},
|
||||
number<1>{});
|
||||
return transform_tensor_descriptor(
|
||||
b_k0_n_k1_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(K0, K1)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(GemmPipeline::Preshuffle)
|
||||
{
|
||||
index_t kFlatK =
|
||||
GemmPipeline::BlockGemmShape::flatKPerWarp *
|
||||
(k_size / GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}));
|
||||
index_t kFlatN = N * K / kFlatK;
|
||||
|
||||
return make_naive_tensor_descriptor(make_tuple(kFlatN, kFlatK),
|
||||
make_tuple(kFlatK, 1),
|
||||
number<GemmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(N, k_size),
|
||||
make_tuple(stride, 1),
|
||||
number<GemmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DLayout, index_t VectorSizeD>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeDefaultDTensorDescriptor(const index_t M, const index_t N, const index_t stride)
|
||||
{
|
||||
if constexpr(std::is_same_v<DLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(M, N), make_tuple(stride, 1), number<VectorSizeD>{}, number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N, M), make_tuple(stride, 1), number<VectorSizeD>{}, number<1>{});
|
||||
}
|
||||
}
|
||||
|
||||
// Builds the default naive 2-D tensor descriptor for the E (output) tensor.
//
// The layout is taken from the enclosing scope's CLayout. Lengths are always
// (M, N); the layout only changes which dimension receives `stride`:
//   RowMajor -> strides (stride, 1), vector size EpiloguePipeline::GetVectorSizeC()
//   other    -> strides (1, stride), vector size 1
//
// TODO: enable vector write for C in ColMajor
CK_TILE_DEVICE static auto
MakeDefaultETensorDescriptor(const index_t M, const index_t N, const index_t stride)
{
    constexpr bool is_row_major = std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>;

    // arguments not matching with flatmm.
    const auto lengths = make_tuple(M, N);

    if constexpr(is_row_major)
    {
        return make_naive_tensor_descriptor(lengths,
                                            make_tuple(stride, 1),
                                            number<EpiloguePipeline::GetVectorSizeC()>{},
                                            number<1>{});
    }
    else
    {
        return make_naive_tensor_descriptor(
            lengths, make_tuple(1, stride), number<1>{}, number<1>{});
    }
}
|
||||
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeABlockWindows(const std::array<const ADataType*, NumATensor>& as_ptr,
|
||||
const KernelArgs& kargs,
|
||||
@@ -672,24 +802,10 @@ struct UniversalGemmKernel
|
||||
[&](auto i) {
|
||||
using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
|
||||
using AiDataType = remove_cvref_t<std::tuple_element_t<i.value, AsDataType>>;
|
||||
if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const AiDataType*>(as_ptr[i]),
|
||||
make_tuple(kargs.M, k_size),
|
||||
make_tuple(kargs.stride_As[i], 1),
|
||||
number<GemmPipeline::GetVectorSizeA()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const AiDataType*>(as_ptr[i]),
|
||||
make_tuple(k_size, kargs.M),
|
||||
make_tuple(kargs.stride_As[i], 1),
|
||||
number<GemmPipeline::GetVectorSizeA()>{},
|
||||
number<1>{});
|
||||
}
|
||||
|
||||
return make_tensor_view<address_space_enum::global>(
|
||||
static_cast<const AiDataType*>(as_ptr[i]),
|
||||
MakeDefaultATensorDescriptor<AiLayout>(kargs.M, kargs.stride_As[i], k_size));
|
||||
},
|
||||
number<NumATensor>{});
|
||||
|
||||
@@ -749,87 +865,10 @@ struct UniversalGemmKernel
|
||||
[&](auto i) {
|
||||
using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
|
||||
using BiDataType = remove_cvref_t<std::tuple_element_t<i.value, BsDataType>>;
|
||||
if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
|
||||
{
|
||||
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
|
||||
const index_t K0 = k_size / K1;
|
||||
constexpr index_t VectorSizeB =
|
||||
std::min(K1, GemmPipeline::GetVectorSizeB());
|
||||
const auto b_k0_n_k1_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
|
||||
make_tuple(kargs.N * K1, K1, I1),
|
||||
number<VectorSizeB>{},
|
||||
number<1>{});
|
||||
const auto b_n_k_desc = transform_tensor_descriptor(
|
||||
b_k0_n_k1_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(K0, K1)),
|
||||
make_pass_through_transform(kargs.N)),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return make_tensor_view<address_space_enum::global>(
|
||||
static_cast<const BiDataType*>(bs_ptr[i]), b_n_k_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
bs_ptr[i],
|
||||
make_tuple(k_size, kargs.N),
|
||||
make_tuple(kargs.stride_Bs[i], 1),
|
||||
number<GemmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
|
||||
{
|
||||
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
|
||||
const index_t K0 = k_size / K1;
|
||||
constexpr index_t VectorSizeB =
|
||||
std::min(K1, GemmPipeline::GetVectorSizeB());
|
||||
const auto b_k0_n_k1_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
|
||||
make_tuple(kargs.N * K1, K1, I1),
|
||||
number<VectorSizeB>{},
|
||||
number<1>{});
|
||||
const auto b_n_k_desc = transform_tensor_descriptor(
|
||||
b_k0_n_k1_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(K0, K1)),
|
||||
make_pass_through_transform(kargs.N)),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
return make_tensor_view<address_space_enum::global>(
|
||||
static_cast<const BiDataType*>(bs_ptr[i]), b_n_k_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(GemmPipeline::Preshuffle)
|
||||
{
|
||||
index_t kFlatK =
|
||||
GemmPipeline::BlockGemmShape::flatKPerWarp *
|
||||
(k_size / GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}));
|
||||
index_t kFlatN = kargs.N * kargs.K / kFlatK;
|
||||
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
bs_ptr[i],
|
||||
make_tuple(kFlatN, kFlatK),
|
||||
make_tuple(kFlatK, 1),
|
||||
number<GemmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
bs_ptr[i],
|
||||
make_tuple(kargs.N, k_size),
|
||||
make_tuple(kargs.stride_Bs[i], 1),
|
||||
number<GemmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
return make_tensor_view<address_space_enum::global>(
|
||||
static_cast<const BiDataType*>(bs_ptr[i]),
|
||||
MakeDefaultBTensorDescriptor<BiLayout>(
|
||||
kargs.N, kargs.K, kargs.stride_Bs[i], k_size));
|
||||
},
|
||||
number<NumBTensor>{});
|
||||
|
||||
@@ -900,24 +939,10 @@ struct UniversalGemmKernel
|
||||
[&](auto i) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
using DDataType_ = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const DDataType_*>(ds_ptr[i]),
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(kargs.stride_Ds[i], 1),
|
||||
number<EpiloguePipeline::GetVectorSizeD(i)>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const DDataType_*>(ds_ptr[i]),
|
||||
make_tuple(kargs.N, kargs.M),
|
||||
make_tuple(kargs.stride_Ds[i], 1),
|
||||
number<EpiloguePipeline::GetVectorSizeD(i)>{},
|
||||
number<1>{});
|
||||
}
|
||||
return make_tensor_view<address_space_enum::global>(
|
||||
static_cast<const DDataType_*>(ds_ptr[i]),
|
||||
MakeDefaultDTensorDescriptor<DiLayout, EpiloguePipeline::GetVectorSizeD(i)>(
|
||||
kargs.M, kargs.N, kargs.stride_Ds[i]));
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
@@ -973,26 +998,8 @@ struct UniversalGemmKernel
|
||||
const index_t i_n)
|
||||
{
|
||||
// Step 1: Create tensor view for E/C tensor (from MakeGemmTensorViews)
|
||||
const auto& e_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
e_ptr,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(kargs.stride_E, 1),
|
||||
number<EpiloguePipeline::GetVectorSizeC()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
e_ptr,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(1, kargs.stride_E),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
const auto& e_tensor_view = make_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
e_ptr, MakeDefaultETensorDescriptor(kargs.M, kargs.N, kargs.stride_E));
|
||||
|
||||
// Step 2: Create padded view (from MakeGemmPadViews)
|
||||
const auto& e_pad_view = [&]() {
|
||||
|
||||
Reference in New Issue
Block a user