mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
moe gemm1 scale ready
This commit is contained in:
@@ -71,7 +71,7 @@ struct MulABScale
|
||||
(void)d2; // for gate, no d2 needed
|
||||
(void)d0;
|
||||
(void)d1;
|
||||
const float x0_f = c;
|
||||
const float x0_f = c * d1 * d0;
|
||||
// const float x0_f = c;
|
||||
e = ck::type_convert<EDataType>(x0_f);
|
||||
}
|
||||
@@ -286,9 +286,9 @@ int main(int argc, char* argv[])
|
||||
case 1:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{0, 2});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
|
||||
d2_m_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{1, 3});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{1, 3});
|
||||
d2_m_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{1, 3});
|
||||
break;
|
||||
case 2:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
@@ -304,6 +304,9 @@ int main(int argc, char* argv[])
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
|
||||
d2_m_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
|
||||
}
|
||||
d0_t_n.savetxt("d0_t_n.txt", "int");
|
||||
d1_e_n.savetxt("d1_e_n.txt", "int");
|
||||
d2_m_n.savetxt("d2_m_n.txt", "int");
|
||||
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize());
|
||||
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize());
|
||||
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.mDesc.GetElementSpaceSize());
|
||||
@@ -325,8 +328,6 @@ int main(int argc, char* argv[])
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
constexpr auto I0 = ck::Number<0>{};
|
||||
|
||||
// do GEMM
|
||||
auto device_op = DeviceOpInstance{};
|
||||
|
||||
@@ -352,7 +353,7 @@ int main(int argc, char* argv[])
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
std::array<ck::index_t, NumDTensor>{I0, I0, I0},
|
||||
StrideDs,
|
||||
StrideE,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
@@ -406,9 +407,10 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
|
||||
const int t = sorted_token_ids(m);
|
||||
const int e = expert_ids(m / sorted_tile_size);
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_t_n(t, n), d1_e_n(m, n), d2_m_n(m, n));
|
||||
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_t_n(t, n), d1_e_n(e, n), d2_m_n(m, n));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1401,7 +1401,7 @@ struct GridwiseMoeGemmGather
|
||||
if (i.value == 1)
|
||||
{
|
||||
ptr_ += expert_id * (problem.StrideDs[1]? problem.StrideDs[1] * problem.N : 1);
|
||||
// if ( threadIdx.x ==0)
|
||||
// if ( threadIdx.x % 16 ==0)
|
||||
// printf("bid %d eid %d b eoff %d %f\n", blockIdx.y, expert_id, expert_id * (problem.StrideDs[1]? problem.StrideDs[1] * problem.N : 1), ptr_[0]);
|
||||
}
|
||||
return make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
@@ -1448,10 +1448,11 @@ struct GridwiseMoeGemmGather
|
||||
StaticallyIndexedArray<index_t, EMRepeats> scatter_offsets; //= p_sorted_token_ids[c_token_pos];
|
||||
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
|
||||
// FIXME: hack - the p_ds_grid index used for the top-k weights is hard-coded here
|
||||
const float *p_sorted_weights = p_ds_grid[I2];
|
||||
const float *p_sorted_weights = p_ds_grid[I0];
|
||||
static_for<0, EMRepeats, 1>{}([&](auto m0) {
|
||||
scatter_offsets(m0) = 0;
|
||||
scatter_weights(m0) = p_sorted_weights[c_token_pos + m0];
|
||||
scatter_weights(m0) = p_sorted_weights[(c_token_pos + m0) * problem.StrideDs[0]];
|
||||
// if(threadIdx.x % 16 == 0)
|
||||
// printf("init off bid %d tid %d m %d off %d\n", blockIdx.y, threadIdx.x, m0(), scatter_offsets(m0));
|
||||
});
|
||||
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
|
||||
|
||||
@@ -176,10 +176,12 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
src_coords_[i]);
|
||||
|
||||
oob_val = oob_val & is_src_valid;
|
||||
if (i.value == ScatterWeightIdx)
|
||||
if (i.value == ScatterWeightIdx)
|
||||
{
|
||||
static_assert(SrcScalarPerVectors{}[Number<2>{}] == 1, "scatter weight dim, should only one vec");
|
||||
static_assert(SrcScalarPerVectors{}[Number<ScatterWeightIdx>{}] == 1, "scatter weight dim, should only one vec");
|
||||
constexpr auto iScatter = SrcSpaceFillingCurve::GetIndex(iAccess)(Number<ScatterDim>{});
|
||||
// if(threadIdx.x % 8 ==0 )
|
||||
// printf("bid %d tid %d srcid %d sv %f\n", blockIdx.y, threadIdx.x, i.value, scatter_weights_(Number<iScatter>{}));
|
||||
static_for<0, SrcScalarPerVector, 1>{}(
|
||||
[&](auto j) { src_vectors(i).template AsType<float>()(j) = scatter_weights_(Number<iScatter>{}); });
|
||||
}
|
||||
@@ -189,11 +191,15 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
using DataType = remove_cvref_t<decltype(data_types[i])>;
|
||||
const auto tmp =
|
||||
src_bufs[i].template Get<DataType>(src_coords_[i].GetOffset(), true);
|
||||
// if(threadIdx.x % 8 ==0 )
|
||||
// printf("bid %d tid %d srcid %d off %d v %f\n", blockIdx.y, threadIdx.x, i.value, src_coords_[i].GetOffset(), tmp);
|
||||
static_for<0, SrcScalarPerVector, 1>{}(
|
||||
[&](auto j) { src_vectors(i).template AsType<DataType>()(j) = tmp; });
|
||||
}
|
||||
else
|
||||
{
|
||||
// if(threadIdx.x % 8 ==0 )
|
||||
// printf("bid %d tid %d srcid %d vn\n", blockIdx.y, threadIdx.x, i.value);
|
||||
src_vectors(i).template AsType<src_vector_t>()(I0) =
|
||||
src_bufs[i].template Get<src_vector_t>(src_coords_[i].GetOffset(), true);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user