feat: support dynamic input tensor

This commit is contained in:
zhangyunze 2023-08-31 15:40:00 +08:00
parent f60767a770
commit 02d02b1129
10 changed files with 103 additions and 4 deletions

View File

@ -53,6 +53,7 @@ class GraphObj : public Object {
const TensorVec &getTensors() const { return tensors; }
const OpVec &getOperators() const { return ops; }
OpVec getComputeOps() const;
Tensor getTensorWithUid(int fuid) const;
/**
* Sort the nodes in topological order.
@ -64,6 +65,8 @@ class GraphObj : public Object {
void optimize();
void shape_infer();
void dataMalloc();
/**

View File

@ -58,7 +58,7 @@ class GraphHandlerObj {
Tensor clip(Tensor x, Tensor y, std::optional<float> min,
std::optional<float> max);
Tensor transpose(Tensor data, Tensor transposed, Shape perm);
Tensor reshape(Tensor data, Tensor reshaped, Shape shape);
Tensor reshape(Tensor data, Tensor reshaped, Shape shape, Shape shape_t);
Tensor concat(TensorVec inputs, Tensor output, int dim);
TensorVec split(Tensor input, std::optional<TensorVec> outputs, int axis,
int num_outputs);
@ -73,6 +73,7 @@ class GraphHandlerObj {
Tensor cast(Tensor input, Tensor output, int to);
Tensor expand(Tensor input, Tensor output, Shape dims);
Tensor where(Tensor inputX, Tensor inputY, Tensor condition, Tensor output);
// std::vector<int> getDims(Tensor x) { return x->getDims(); }
Tensor allReduceSum(Tensor input, Tensor output);
Tensor allReduceProd(Tensor input, Tensor output);
@ -88,6 +89,9 @@ class GraphHandlerObj {
inline void optimize() { g->optimize(); }
inline void shape_infer() { g->shape_infer(); }
void change_shape(const vector<int> &shape, int tensorId);
//------ runtime
inline void data_malloc() { g->dataMalloc(); }

View File

@ -25,9 +25,11 @@ class TensorObj : public TensorBaseObj {
string toString() const override;
size_t size() const { return _size; }
void setSize(size_t size) { _size = size;}
size_t getBytes() const { return _size * dtype.getSize(); }
Shape getDims() const { return shape; }
void setShape(Shape shape_) { shape = shape_; }
size_t getRank() const { return shape.size(); }
Shape getStride() const;
size_t getOffset(const vector<int> &ds) const;

View File

@ -9,6 +9,7 @@ namespace infini {
*/
class ReshapeObj : public OperatorObj {
Shape dims;
Shape dims_t = {};
public:
/**
@ -18,8 +19,10 @@ class ReshapeObj : public OperatorObj {
* @param input The input tensor.
* @param output The output tensor.
* @param dims The shape of the output tensor.
* @param dims_t The original shape values as given (may contain 0 or -1 placeholders to be resolved during shape inference).
*/
ReshapeObj(GraphObj *graph, Tensor input, Tensor output, Shape dims);
ReshapeObj(GraphObj *graph, Tensor input, Tensor output, Shape dims, Shape dims_t);
OP_CLONE(ReshapeObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
@ -29,6 +32,8 @@ class ReshapeObj : public OperatorObj {
int numOutputs() const override { return 1; }
inline Shape getShape() const { return dims; }
inline void setShape(Shape shape_) { dims = shape_; }
inline Shape getShape_t() const { return dims_t; }
private:
vector<int> getWorkloadVector() const override;

View File

@ -444,6 +444,7 @@ class OnnxStub:
dims = _search_shape(model, node.input[0])
size = reduce(lambda acc, x: acc * x, dims)
input_shape = _parse_data(data[node.input[1]])
origin_shape = input_shape.copy()
for i, x in enumerate(input_shape):
if x == 0:
input_shape[i] = dims[i]
@ -454,6 +455,7 @@ class OnnxStub:
tensors[node.input[0]],
tensors.get(node.output[0]),
input_shape,
origin_shape,
)
elif node.op_type == "Squeeze":
input_shape = _search_shape(model, node.input[0])
@ -1001,12 +1003,26 @@ class OnnxStub:
def optimize(self) -> None:
    """Run the backend handler's graph-optimization passes in place."""
    self.handler.optimize()
def set_input(self, inputShapes: List[List[int]]) -> None:
    """Resize the model's input tensors to new shapes.

    ``inputShapes`` is paired positionally with ``self.inputs``; each entry
    is the new dimension list for the corresponding input tensor.  After
    every input shape is changed, shapes are re-inferred through the whole
    graph and tensor storage is re-allocated.
    """
    for newInput, oldInput in zip(inputShapes, self.inputs):
        oldTensor = self.inputs[oldInput]
        self.handler.change_shape(newInput, oldTensor.fuid())
    # Propagate the new shapes through every operator, then rebuild buffers.
    self.handler.shape_infer()
    self.handler.data_malloc()
def tune(self) -> None:
    """Delegate kernel tuning to the backend handler."""
    self.handler.tune()
def run(self) -> None:
    """Execute the graph via the backend handler."""
    self.handler.run()
# def getShape(self, name: str) -> List[int]:
# if name in self.inputs:
# ans = self.handler.getDims(self.inputs[name])
# else:
# ans = self.handler.getDims(self.outputs[name])
# return ans
def get_perf_time(self) -> float:
    """Return the execution time measured by the backend handler.

    The annotation promises a float; propagate the handler's result
    instead of silently dropping it.
    """
    return self.handler.get_perf_time()

View File

@ -1,6 +1,8 @@
#include "core/graph.h"
#include <algorithm>
#include <queue>
#include <numeric>
#include "operators/reshape.h"
namespace infini {
@ -123,6 +125,58 @@ void GraphObj::optimize() {
}
}
// Look up a graph tensor by its fuid (unique tensor id).
// Linear scan over `tensors`; returns nullptr when no tensor matches.
Tensor GraphObj::getTensorWithUid(int fuid) const {
    // Iterate by const reference: Tensor is a ref-counted handle and
    // copying it per iteration is needless refcount traffic.
    for (const auto &tensor : tensors) {
        if (tensor->getFuid() == fuid) {
            return tensor;
        }
    }
    return nullptr;
}
void GraphObj::shape_infer() {
for (auto &op : ops) {
if (op->getOpType() == OpType::Reshape) {
auto reshape = dynamic_cast<ReshapeObj *>(op.get());
auto input = reshape->getInputs(0)->getDims();
auto size = reshape->getInputs(0)->size();
Shape ans = reshape->getShape_t();
int index = -1;
for (int i = 0; i < (int)ans.size(); ++i) {
if (ans[i] == 0) {
ans[i] = input[i];
}
if (ans[i] == -1) {
index = i;
}
}
if (index != -1) {
int temp = (int)size / (-std::accumulate(ans.begin(), ans.end(), 1, [](auto acc, auto x) {return acc * x;}));
ans[index] = temp;
}
reshape->setShape(ans);
}
auto ans = op->inferShape();
IT_ASSERT(ans.has_value());
std::cout<<"optype = "<<op->getOpType().toString()<<std::endl;
auto oldOutputs = op->getOutputs();
IT_ASSERT(ans.value().size() == oldOutputs.size());
for (int i = 0; i < (int)ans.value().size(); ++i) {
auto newShape = ans.value()[i];
std::cout<<vecToString(newShape)<<std::endl;
auto oldShape = oldOutputs[i]->getDims();
auto fuid = oldOutputs[i]->getFuid();
if (newShape != oldShape) {
auto tensor = this->getTensorWithUid(fuid);
tensor->setShape(newShape);
size_t size = std::accumulate(newShape.begin(), newShape.end(), 1, [](auto acc, auto x) {return acc * x;});
tensor->setSize(size);
std::cout<<"replace newShape over"<<std::endl;
}
}
}
}
void GraphObj::dataMalloc() {
// topological sorting first
IT_ASSERT(topo_sort() == true);

View File

@ -215,13 +215,13 @@ Tensor GraphHandlerObj::transpose(Tensor data, Tensor transposed, Shape perm) {
}
}
// Add a Reshape op to the graph. `shape` is the resolved output shape;
// `shape_t` keeps the raw shape values (with 0/-1 placeholders) so the
// shape can be re-inferred when the input is resized later.
// When the caller supplied an output tensor, bind it; otherwise let the
// graph allocate one and return it.
Tensor GraphHandlerObj::reshape(Tensor data, Tensor reshaped, Shape shape,
                                Shape shape_t) {
    if (reshaped) {
        g->addOpWithOutputs<ReshapeObj>(std::move(data), reshaped,
                                        std::move(shape), std::move(shape_t));
        return reshaped;
    } else {
        return g
            ->addOp<ReshapeObj>(std::move(data), reshaped, std::move(shape),
                                std::move(shape_t))
            ->getOutput();
    }
}
@ -502,4 +502,9 @@ static DataType dtype_repr_convert(int dtype) {
}
}
void GraphHandlerObj::change_shape(const vector<int> &shape, int tensorId) {
auto tensor = g->getTensorWithUid(tensorId);
tensor->setShape(shape);
}
} // namespace infini

View File

@ -458,6 +458,9 @@ void init_graph_builder(py::module &m) {
.def("get_perf_time", &Handler::get_perf_time, policy::automatic)
.def("tune", &Handler::tune, policy::automatic)
.def("run", &Handler::run, policy::automatic)
.def("shape_infer", &Handler::shape_infer, policy::automatic)
.def("change_shape", &Handler::change_shape, policy::automatic)
// .def("getDims", &Handler::getDims, policy::automatic)
.def("get_perf_time", &Handler::get_perf_time, policy::automatic);
}

View File

@ -49,6 +49,8 @@ optional<vector<Shape>> MatmulObj::inferShape(const TensorVec &inputs) const {
Shape shapeA1(shapeA.begin(), shapeA.begin() + (rankA - 2));
Shape shapeB1(shapeB.begin(), shapeB.begin() + (rankB - 2));
Shape ret = infer_broadcast(shapeA1, shapeB1);
auto m = *(this->getTransA() ? shapeA.rbegin() : shapeA.rbegin() + 1);
auto n = *(this->getTransB() ? shapeB.rbegin() + 1 : shapeB.rbegin());
ret.emplace_back(m);
ret.emplace_back(n);
return {{ret}};

View File

@ -7,6 +7,11 @@ ReshapeObj::ReshapeObj(GraphObj *graph, Tensor input, Tensor output, Shape dims)
IT_ASSERT(checkValid(graph));
}
// Construct a Reshape op that also records the raw shape attribute.
// `dims` is the concrete output shape; `dims_t` keeps the values as given
// (they may contain 0/-1 placeholders, resolved later by shape_infer).
ReshapeObj::ReshapeObj(GraphObj *graph, Tensor input, Tensor output, Shape dims, Shape dims_t)
    : OperatorObj(OpType::Reshape, {input}, {output}), dims(std::move(dims)), dims_t(std::move(dims_t)) {
    IT_ASSERT(checkValid(graph));
}
optional<vector<Shape>> ReshapeObj::inferShape(const TensorVec &inputs) const {
size_t size = 1;
for (size_t i = 0; i < dims.size(); ++i) {