rename example

This commit is contained in:
Feng Shijie
2025-08-17 17:51:18 +00:00
parent 7899fb4a8d
commit 599e1f5b32
7 changed files with 31 additions and 29 deletions

View File

@@ -84,7 +84,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
static constexpr index_t flatNPerWarp = Problem::flatNPerWarp;
// Vector size for operand A, forwarded unchanged from Problem::VectorSizeA.
static constexpr index_t GetVectorSizeA() { return Problem::VectorSizeA; }
static constexpr index_t GetVectorSizeB() { return Problem::VectorSizeB; }
static constexpr index_t GetVectorSizeB() { return 32; /* fixed for fp4 shuffle layout*/ }
// Vector size for operand C, forwarded unchanged from Problem::VectorSizeC.
static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; }
static constexpr bool kPadM = Problem::kPadM;
@@ -470,11 +470,16 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
// __builtin_amdgcn_sched_barrier(0);
}
// Returns the tile distribution that PipelinePolicy::MakeADramTileDistribution
// produces for this pipeline's Problem type.
// NOTE(review): presumably exposed so that callers can construct the A DRAM
// tile window with this distribution themselves before passing it to
// operator() — confirm against the call sites.
CK_TILE_HOST_DEVICE static constexpr auto GetADramTileDistribution()
{
return PipelinePolicy::template MakeADramTileDistribution<Problem>();
}
template <typename ADramBlockWindowTmp,
typename AElementFunction,
typename BFlatBlockWindowTmp,
typename DequantBFlatWindow>
CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
CK_TILE_HOST_DEVICE auto operator()(ADramBlockWindowTmp a_copy_dram_window,
const AElementFunction& a_element_func,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
const DequantBFlatWindow& scale_b_flat_window,
@@ -524,19 +529,11 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
// MakeF16xF4_ALDS_TileDistribution<Problem>(); auto A_Lds_TileDist =
// PipelinePolicy::template MakeADramTileDistribution<Problem>(); auto A_Lds_Stride = 8;
// A DRAM tile window for load
auto a_copy_dram_window =
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
a_dram_block_window_tmp.get_window_origin(),
PipelinePolicy::template MakeADramTileDistribution<Problem>());
auto a_copy_lds_window_ping =
make_tile_window(a_lds_block_ping,
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
A_Lds_TileDist);
auto a_copy_lds_window_pong =
make_tile_window(a_lds_block_pong,
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
@@ -549,7 +546,6 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
{iMWarp * WG::kM, 0},
A_XDL_TileDist);
auto a_warp_window_pong_tmp =
make_tile_window(a_lds_block_pong,
make_tuple(number<WG::kM>{}, number<WG::kK>{}),