Merge commit '47e2ed838e3547bba1b48d3f559f20f46fd07b87' into develop

This commit is contained in:
assistant-librarian[bot]
2025-11-20 02:43:03 +00:00
parent ca48bf3b98
commit 809c1ead72
183 changed files with 987 additions and 863 deletions

View File

@@ -143,16 +143,24 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
}
}();
index_t kFlatK = kargs.K * BlockGemmShape::WarpTile::at(I1);
index_t kFlatN = kargs.N * kargs.K / kFlatK;
const auto& b_flat_tensor_view = [&]() {
return make_naive_tensor_view<address_space_enum::global>(
b_flat_ptr,
make_tuple(kFlatN, kFlatK),
make_tuple(kFlatK, 1),
number<FlatmmPipeline::GetVectorSizeB()>{},
number<1>{});
// Build the global-address-space tensor view over b_flat_ptr.
// The view is 2-D: dimension 0 has length kFlatN, and dimension 1 is the merge
// of (kFlatKBlocks, flatKPerBlock).
constexpr index_t kKPerBlock = FlatmmPipeline::kKPerBlock;
// NOTE(review): presumably the N extent of the warp tile (index I1) — confirm
// against the WarpTile definition.
constexpr index_t kNWarpTile = BlockGemmShape::WarpTile::at(I1);
// Compile-time length of the innermost dimension of the flat B descriptor.
constexpr index_t flatKPerBlock = kKPerBlock * kNWarpTile;
// Integer divisions: any remainder is dropped.
// NOTE(review): assumes kargs.K is a multiple of kKPerBlock and kargs.N is a
// multiple of kNWarpTile — TODO confirm the caller guarantees this.
const index_t kFlatKBlocks = kargs.K / kKPerBlock;
const index_t kFlatN = kargs.N / kNWarpTile;
const auto& b_flat_tensor_view = [&]() {
// Compile-time guard: the innermost length must be divisible by the B vector
// size reported by the pipeline.
static_assert(flatKPerBlock % FlatmmPipeline::GetVectorSizeB() == 0,
"wrong! vector size for B tensor");
// 3-D packed descriptor with lengths (kFlatN, kFlatKBlocks, flatKPerBlock);
// only the last length is a compile-time constant.
auto&& naive_desc = make_naive_tensor_descriptor_packed(
make_tuple(kFlatN, kFlatKBlocks, number<flatKPerBlock>{}));
// Reshape to 2-D: dimension 0 passes through unchanged, while source
// dimensions 1 and 2 are merged into the single output dimension 1.
auto&& desc = transform_tensor_descriptor(
naive_desc,
make_tuple(make_pass_through_transform(kFlatN),
make_merge_transform_v3_division_mod(
make_tuple(kFlatKBlocks, number<flatKPerBlock>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return make_tensor_view<address_space_enum::global>(b_flat_ptr, desc);
}();
const auto& ds_tensor_view = generate_tuple(