forked from jiuyuan/InfiniTensor
parent
2a23669394
commit
bb9b62b169
|
@ -95,17 +95,17 @@ static int tensor_dtype(Tensor t) {
|
|||
|
||||
static int concat_axis_of(Operator op) {
|
||||
IT_ASSERT(op->getOpType() == OpType::Concat);
|
||||
return reinterpret_cast<const ConcatObj *>(op.get())->getDim();
|
||||
return dynamic_cast<const ConcatObj *>(op.get())->getDim();
|
||||
}
|
||||
|
||||
static int gather_axis_of(Operator op) {
|
||||
IT_ASSERT(op->getOpType() == OpType::Gather);
|
||||
return reinterpret_cast<const GatherObj *>(op.get())->getAxis();
|
||||
return dynamic_cast<const GatherObj *>(op.get())->getAxis();
|
||||
}
|
||||
|
||||
static Shape reshape_shape_of(Operator op) {
|
||||
IT_ASSERT(op->getOpType() == OpType::Reshape);
|
||||
return reinterpret_cast<const ReshapeObj *>(op.get())->getShape();
|
||||
return dynamic_cast<const ReshapeObj *>(op.get())->getShape();
|
||||
}
|
||||
|
||||
void init_graph_builder(py::module &m) {
|
||||
|
|
Loading…
Reference in New Issue