mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
rename example
This commit is contained in:
@@ -134,11 +134,12 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Ep
|
||||
index_t kFlatN = kargs.N * kargs.K / kFlatK;
|
||||
|
||||
const auto& b_flat_tensor_view = [&]() {
|
||||
return make_naive_tensor_view<address_space_enum::global>(b_flat_ptr,
|
||||
make_tuple(kFlatN, kFlatK),
|
||||
make_tuple(kFlatK, 1),
|
||||
number<32>{},
|
||||
number<1>{});
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_flat_ptr,
|
||||
make_tuple(kFlatN, kFlatK),
|
||||
make_tuple(kFlatK, 1),
|
||||
number<FlatmmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
}();
|
||||
|
||||
const auto& ds_tensor_view = generate_tuple(
|
||||
@@ -368,12 +369,18 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Ep
|
||||
const auto& b_flat_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& d_block_window = gemm_tile_windows.at(I2);
|
||||
const auto& scale_block_window = gemm_tile_windows.at(I4);
|
||||
const auto& c_block_tile = FlatmmPipeline{}.template operator()(a_block_window,
|
||||
b_flat_block_window,
|
||||
scale_block_window,
|
||||
num_loop,
|
||||
smem_ptr_ping,
|
||||
smem_ptr_pong);
|
||||
|
||||
auto a_block_window_with_distr =
|
||||
ck_tile::make_tile_window(a_block_window.get_bottom_tensor_view(),
|
||||
a_block_window.get_window_lengths(),
|
||||
a_block_window.get_window_origin(),
|
||||
FlatmmPipeline::GetADramTileDistribution());
|
||||
const auto& c_block_tile = FlatmmPipeline{}(a_block_window_with_distr,
|
||||
b_flat_block_window,
|
||||
scale_block_window,
|
||||
num_loop,
|
||||
smem_ptr_ping,
|
||||
smem_ptr_pong);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
if constexpr(false && (ScaleM::GranularityMN != -1 && ScaleM::GranularityK == 0) ||
|
||||
|
||||
@@ -84,7 +84,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
static constexpr index_t flatNPerWarp = Problem::flatNPerWarp;
|
||||
|
||||
static constexpr index_t GetVectorSizeA() { return Problem::VectorSizeA; }
|
||||
static constexpr index_t GetVectorSizeB() { return Problem::VectorSizeB; }
|
||||
static constexpr index_t GetVectorSizeB() { return 32; /* fixed for fp4 shuffle layout*/ }
|
||||
static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; }
|
||||
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
@@ -470,11 +470,16 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetADramTileDistribution()
|
||||
{
|
||||
return PipelinePolicy::template MakeADramTileDistribution<Problem>();
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BFlatBlockWindowTmp,
|
||||
typename DequantBFlatWindow>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
CK_TILE_HOST_DEVICE auto operator()(ADramBlockWindowTmp a_copy_dram_window,
|
||||
const AElementFunction& a_element_func,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
const DequantBFlatWindow& scale_b_flat_window,
|
||||
@@ -524,19 +529,11 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
// MakeF16xF4_ALDS_TileDistribution<Problem>(); auto A_Lds_TileDist =
|
||||
// PipelinePolicy::template MakeADramTileDistribution<Problem>(); auto A_Lds_Stride = 8;
|
||||
|
||||
// A DRAM tile window for load
|
||||
auto a_copy_dram_window =
|
||||
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
a_dram_block_window_tmp.get_window_origin(),
|
||||
PipelinePolicy::template MakeADramTileDistribution<Problem>());
|
||||
|
||||
auto a_copy_lds_window_ping =
|
||||
make_tile_window(a_lds_block_ping,
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
A_Lds_TileDist);
|
||||
|
||||
auto a_copy_lds_window_pong =
|
||||
make_tile_window(a_lds_block_pong,
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
@@ -549,7 +546,6 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
|
||||
{iMWarp * WG::kM, 0},
|
||||
A_XDL_TileDist);
|
||||
|
||||
auto a_warp_window_pong_tmp =
|
||||
make_tile_window(a_lds_block_pong,
|
||||
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
|
||||
|
||||
Reference in New Issue
Block a user