mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
g ad d add pading
This commit is contained in:
@@ -292,9 +292,14 @@ struct FusedMoeGemmGlKernel
|
||||
number<Pipeline::kAlignmentG>{},
|
||||
number<1>{});
|
||||
|
||||
const auto g_window_ = make_tile_window(
|
||||
const auto g_view_1_ = pad_tensor_view(
|
||||
g_view_,
|
||||
make_tuple(number<BlockShape::Block_N0>{}, number<BlockShape::Block_K0>{}),
|
||||
sequence<PadIntermediateSize, PadHiddenSize>{});
|
||||
|
||||
const auto g_window_ = make_tile_window(
|
||||
g_view_1_,
|
||||
make_tuple(number<BlockShape::Block_N0>{}, number<BlockShape::Block_K0>{}),
|
||||
{idx_n0, 0});
|
||||
|
||||
return g_window_;
|
||||
@@ -328,9 +333,14 @@ struct FusedMoeGemmGlKernel
|
||||
number<Pipeline::kAlignmentD>{},
|
||||
number<1>{});
|
||||
|
||||
const auto d_window_ = make_tile_window(
|
||||
const auto d_view_1_ = pad_tensor_view(
|
||||
d_view_,
|
||||
make_tuple(number<BlockShape::Block_N1>{}, number<BlockShape::Block_K1>{}),
|
||||
sequence<PadHiddenSize, PadIntermediateSize>{});
|
||||
|
||||
const auto d_window_ = make_tile_window(
|
||||
d_view_1_,
|
||||
make_tuple(number<BlockShape::Block_N1>{}, number<BlockShape::Block_K1>{}),
|
||||
{0, idx_n0});
|
||||
return d_window_;
|
||||
}();
|
||||
|
||||
@@ -391,7 +391,7 @@ struct FusedMoeGemmKernel
|
||||
number<Pipeline::kAlignmentO>{},
|
||||
number<1>{});
|
||||
|
||||
// gather is here
|
||||
// scatter is here
|
||||
auto o_scatter_view_ = transform_tensor_view(
|
||||
o_view_,
|
||||
make_tuple(make_indexing_transform(kargs.num_tokens, token_id),
|
||||
|
||||
@@ -71,9 +71,7 @@ struct FusedMoeGemmPipeline_General
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeA()
|
||||
{
|
||||
// matrix a or tokens smem
|
||||
constexpr index_t smem_mat_a =
|
||||
BlockShape::Block_M0 * BlockShape::Block_K0 * sizeof(ADataType);
|
||||
return smem_mat_a;
|
||||
return Policy::template GetSmemSize_A<Problem>();
|
||||
}
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
@@ -131,11 +129,8 @@ struct FusedMoeGemmPipeline_General
|
||||
CK_TILE_LDS_ADDR void* smem,
|
||||
index_t hidden_size,
|
||||
index_t /*intermediate_size*/,
|
||||
CWindow& c_window_)
|
||||
CWindow& /*c_window_*/)
|
||||
{
|
||||
ignore = c_window_;
|
||||
ignore = hidden_size;
|
||||
ignore = w_window_;
|
||||
CK_TILE_LDS_ADDR ADataType* smem_0 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem);
|
||||
CK_TILE_LDS_ADDR GDataType* smem_1 = reinterpret_cast<CK_TILE_LDS_ADDR GDataType*>(
|
||||
smem_0 + GetSmemSizeA() / sizeof(ADataType));
|
||||
@@ -234,11 +229,11 @@ struct FusedMoeGemmPipeline_General
|
||||
#if 0
|
||||
PrintMem(y_pre, "Y_pre", 0);
|
||||
#endif
|
||||
if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
|
||||
{
|
||||
block_sync_lds();
|
||||
store_tile(c_window_, y_pre);
|
||||
}
|
||||
// if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
|
||||
// {
|
||||
// block_sync_lds();
|
||||
// store_tile(c_window_, y_pre);
|
||||
// }
|
||||
// save to lds
|
||||
auto bridge_lds_view = make_tensor_view<address_space_enum::lds>(
|
||||
smem_0, Policy::template MakeBridgeLdsBlockDesc<Problem>());
|
||||
|
||||
@@ -312,12 +312,6 @@ struct FusedMoeGemmPipelineGeneralPolicy
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
// constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
// make_tuple(number<Block_M>{}, number<Block_K>{}),
|
||||
// make_tuple(number<Block_K>{}, number<1>{}),
|
||||
// number<8>{},
|
||||
// number<1>{});
|
||||
|
||||
return a_lds_block_desc;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user