forked from jiuyuan/InfiniTensor
parent
2a23669394
commit
bb9b62b169
|
@ -95,17 +95,17 @@ static int tensor_dtype(Tensor t) {
|
||||||
|
|
||||||
static int concat_axis_of(Operator op) {
|
static int concat_axis_of(Operator op) {
|
||||||
IT_ASSERT(op->getOpType() == OpType::Concat);
|
IT_ASSERT(op->getOpType() == OpType::Concat);
|
||||||
return reinterpret_cast<const ConcatObj *>(op.get())->getDim();
|
return dynamic_cast<const ConcatObj *>(op.get())->getDim();
|
||||||
}
|
}
|
||||||
|
|
||||||
static int gather_axis_of(Operator op) {
|
static int gather_axis_of(Operator op) {
|
||||||
IT_ASSERT(op->getOpType() == OpType::Gather);
|
IT_ASSERT(op->getOpType() == OpType::Gather);
|
||||||
return reinterpret_cast<const GatherObj *>(op.get())->getAxis();
|
return dynamic_cast<const GatherObj *>(op.get())->getAxis();
|
||||||
}
|
}
|
||||||
|
|
||||||
static Shape reshape_shape_of(Operator op) {
|
static Shape reshape_shape_of(Operator op) {
|
||||||
IT_ASSERT(op->getOpType() == OpType::Reshape);
|
IT_ASSERT(op->getOpType() == OpType::Reshape);
|
||||||
return reinterpret_cast<const ReshapeObj *>(op.get())->getShape();
|
return dynamic_cast<const ReshapeObj *>(op.get())->getShape();
|
||||||
}
|
}
|
||||||
|
|
||||||
void init_graph_builder(py::module &m) {
|
void init_graph_builder(py::module &m) {
|
||||||
|
|
Loading…
Reference in New Issue