mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Refactor the design of DeviceGemmMultipleDMultipleR_Xdl_CShuffle (#378)
This commit is contained in:
@@ -3,9 +3,6 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
#include <array>
|
||||
|
||||
#include "device_base.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
|
||||
@@ -3,15 +3,27 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <array>
|
||||
|
||||
#include "device_base.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// FIXME: DeviceGemmReduce type need to well define the problem
|
||||
// GEMM:
|
||||
// input : A[AK0, M, AK1]
|
||||
// input : B[AK0, N, AK1]
|
||||
// input : D0[M, N], D1[M, N], ...
|
||||
// output : E[M, N]
|
||||
// output : R0[M], R1[M], ...
|
||||
// C = a_op(A) * b_op(B)
|
||||
// E = cde_op(C, D0, D1, ...)
|
||||
// Q0 = reduce0(q_op0(E)), Q1 = reduce1(q_op0(E)), ...
|
||||
// R0 = r_op0(Q0), R1 = r_op1(Q1), ...
|
||||
// Assume:
|
||||
// D0, D1, ... and E have the same layout
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DELayout,
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
@@ -192,7 +193,10 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
|
||||
static constexpr auto matrix_padder =
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
|
||||
|
||||
static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
|
||||
{
|
||||
const auto a_grid_desc_mraw_kraw = [&]() {
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
||||
@@ -207,95 +211,10 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
|
||||
}
|
||||
}();
|
||||
|
||||
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
|
||||
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
|
||||
|
||||
const auto MPad = M - MRaw;
|
||||
const auto KPad = K - KRaw;
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding)
|
||||
{
|
||||
// pad both M and K
|
||||
assert(K % AK1 == 0);
|
||||
|
||||
const auto AK0 = K / AK1;
|
||||
|
||||
const auto a_grid_desc_m_k =
|
||||
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
|
||||
make_tuple(make_right_pad_transform(MRaw, MPad),
|
||||
make_right_pad_transform(KRaw, KPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 =
|
||||
transform_tensor_descriptor(a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
|
||||
make_pass_through_transform(M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return a_grid_desc_ak0_m_ak1;
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
|
||||
GemmSpec == GemmSpecialization::MNPadding)
|
||||
{
|
||||
// pad M, but not K
|
||||
assert(KRaw % AK1 == 0);
|
||||
|
||||
const auto AK0 = KRaw / AK1;
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 =
|
||||
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
|
||||
make_right_pad_transform(MRaw, MPad)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return a_grid_desc_ak0_m_ak1;
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding)
|
||||
{
|
||||
// pad K, but not M
|
||||
assert(K % AK1 == 0);
|
||||
|
||||
const auto AK0 = K / AK1;
|
||||
|
||||
const auto a_grid_desc_m_k = transform_tensor_descriptor(
|
||||
a_grid_desc_mraw_kraw,
|
||||
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 =
|
||||
transform_tensor_descriptor(a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
|
||||
make_pass_through_transform(MRaw)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return a_grid_desc_ak0_m_ak1;
|
||||
}
|
||||
else
|
||||
{
|
||||
// not pad M or K
|
||||
assert(KRaw % AK1 == 0);
|
||||
|
||||
const auto AK0 = KRaw / AK1;
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 =
|
||||
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
|
||||
make_pass_through_transform(MRaw)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return a_grid_desc_ak0_m_ak1;
|
||||
}
|
||||
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
|
||||
}
|
||||
|
||||
static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
|
||||
static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
|
||||
{
|
||||
const auto b_grid_desc_nraw_kraw = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
|
||||
@@ -310,92 +229,7 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
|
||||
}
|
||||
}();
|
||||
|
||||
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
|
||||
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
|
||||
|
||||
const auto NPad = N - NRaw;
|
||||
const auto KPad = K - KRaw;
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding)
|
||||
{
|
||||
// pad both N and K
|
||||
assert(K % BK1 == 0);
|
||||
|
||||
const auto BK0 = K / BK1;
|
||||
|
||||
const auto b_grid_desc_n_k =
|
||||
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
|
||||
make_tuple(make_right_pad_transform(NRaw, NPad),
|
||||
make_right_pad_transform(KRaw, KPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto b_grid_desc_bk0_n_bk1 =
|
||||
transform_tensor_descriptor(b_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return b_grid_desc_bk0_n_bk1;
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
|
||||
GemmSpec == GemmSpecialization::MNPadding)
|
||||
{
|
||||
// pad N, but not K
|
||||
assert(KRaw % BK1 == 0);
|
||||
|
||||
const auto BK0 = KRaw / BK1;
|
||||
|
||||
const auto b_grid_desc_bk0_n_bk1 =
|
||||
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_right_pad_transform(NRaw, NPad)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return b_grid_desc_bk0_n_bk1;
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
|
||||
GemmSpec == GemmSpecialization::MKPadding)
|
||||
{
|
||||
// pad K, but not N
|
||||
assert(K % BK1 == 0);
|
||||
|
||||
const auto BK0 = K / BK1;
|
||||
|
||||
const auto b_grid_desc_n_k = transform_tensor_descriptor(
|
||||
b_grid_desc_nraw_kraw,
|
||||
make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto b_grid_desc_bk0_n_bk1 =
|
||||
transform_tensor_descriptor(b_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_pass_through_transform(NRaw)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return b_grid_desc_bk0_n_bk1;
|
||||
}
|
||||
else
|
||||
{
|
||||
// not pad N or K
|
||||
assert(KRaw % BK1 == 0);
|
||||
|
||||
const auto BK0 = KRaw / BK1;
|
||||
|
||||
const auto b_grid_desc_bk0_n_bk1 =
|
||||
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_pass_through_transform(NRaw)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return b_grid_desc_bk0_n_bk1;
|
||||
}
|
||||
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
|
||||
}
|
||||
|
||||
static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
|
||||
@@ -413,47 +247,7 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
|
||||
}
|
||||
}();
|
||||
|
||||
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
|
||||
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
|
||||
|
||||
const auto MPad = M - MRaw;
|
||||
const auto NPad = N - NRaw;
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding)
|
||||
{
|
||||
// pad M and N
|
||||
return transform_tensor_descriptor(e_grid_desc_mraw_nraw,
|
||||
make_tuple(make_right_pad_transform(MRaw, MPad),
|
||||
make_right_pad_transform(NRaw, NPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
|
||||
GemmSpec == GemmSpecialization::MKPadding)
|
||||
{
|
||||
// pad M, but not N
|
||||
return transform_tensor_descriptor(
|
||||
e_grid_desc_mraw_nraw,
|
||||
make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding)
|
||||
{
|
||||
// pad N, but not M
|
||||
return transform_tensor_descriptor(
|
||||
e_grid_desc_mraw_nraw,
|
||||
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
// not pad M or N
|
||||
return e_grid_desc_mraw_nraw;
|
||||
}
|
||||
return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
|
||||
}
|
||||
|
||||
// assume D is packed tensor
|
||||
@@ -482,10 +276,10 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
|
||||
}
|
||||
}
|
||||
|
||||
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
|
||||
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
|
||||
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1));
|
||||
using RGridDesc_M = decltype(MakeRGridDescriptor_M(1));
|
||||
using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
|
||||
using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
|
||||
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1));
|
||||
using RGridDesc_M = decltype(MakeRGridDescriptor_M(1));
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
|
||||
@@ -504,8 +298,8 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
|
||||
ThreadReduceOperations,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
RsGlobalMemoryDataOperation,
|
||||
AGridDesc_AK0_M_AK1,
|
||||
BGridDesc_BK0_N_BK1,
|
||||
AGridDesc_M_K,
|
||||
BGridDesc_N_K,
|
||||
EGridDesc_M_N,
|
||||
RGridDesc_M,
|
||||
NumGemmKPrefetchStage,
|
||||
@@ -542,6 +336,13 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
|
||||
RThreadTransferDstScalarPerVector_MPerBlock,
|
||||
LoopSched>;
|
||||
|
||||
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
|
||||
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
|
||||
|
||||
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
@@ -567,12 +368,16 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
|
||||
p_ds_grid_{}, // FIXME
|
||||
p_e_grid_{static_cast<EDataType*>(p_e_grid)},
|
||||
p_rs_grid_{}, // FIXME
|
||||
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
|
||||
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)},
|
||||
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)},
|
||||
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(MRaw, NRaw, StrideE)},
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
r_grid_desc_m_{DeviceOp::MakeRGridDescriptor_M(MRaw)},
|
||||
a_grid_desc_ak0_m_ak1_{
|
||||
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
|
||||
b_grid_desc_bk0_n_bk1_{
|
||||
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
rs_grid_desc_mblock_mperblock_{},
|
||||
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
|
||||
a_element_op_{a_element_op},
|
||||
@@ -581,8 +386,8 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
|
||||
qs_element_op_{qs_element_op},
|
||||
rs_element_op_{rs_element_op}
|
||||
{
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
|
||||
b_grid_desc_bk0_n_bk1_,
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
|
||||
b_grid_desc_n_k_,
|
||||
e_grid_desc_m_n_,
|
||||
r_grid_desc_m_,
|
||||
block_2_etile_map_))
|
||||
@@ -624,6 +429,12 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
|
||||
typename GridwiseGemm::RsGridPointer p_rs_grid_;
|
||||
|
||||
// tensor descriptors
|
||||
AGridDesc_M_K a_grid_desc_m_k_;
|
||||
BGridDesc_N_K b_grid_desc_n_k_;
|
||||
EGridDesc_M_N e_grid_desc_m_n_;
|
||||
RGridDesc_M r_grid_desc_m_;
|
||||
|
||||
// tensor descriptors for block/thread-wise copy
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
|
||||
StaticallyIndexedArray<
|
||||
@@ -631,16 +442,14 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
|
||||
NumDTensor>
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_; // FIXME: Ds desc may be of different
|
||||
// type from E
|
||||
EGridDesc_M_N e_grid_desc_m_n_;
|
||||
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
|
||||
RGridDesc_M r_grid_desc_m_;
|
||||
StaticallyIndexedArray<typename GridwiseGemm::RGridDescriptor_MBlock_MPerBlock, NumRTensor>
|
||||
rs_grid_desc_mblock_mperblock_;
|
||||
|
||||
// block-to-e-tile map
|
||||
typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map_;
|
||||
Block2ETileMap block_2_etile_map_;
|
||||
|
||||
// element-wise op
|
||||
AElementwiseOperation a_element_op_;
|
||||
@@ -657,8 +466,8 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
|
||||
arg.b_grid_desc_n_k_,
|
||||
arg.e_grid_desc_m_n_,
|
||||
arg.r_grid_desc_m_,
|
||||
arg.block_2_etile_map_))
|
||||
@@ -750,8 +559,8 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
|
||||
arg.b_grid_desc_n_k_,
|
||||
arg.e_grid_desc_m_n_,
|
||||
arg.r_grid_desc_m_,
|
||||
arg.block_2_etile_map_);
|
||||
|
||||
Reference in New Issue
Block a user