forked from jiuyuan/InfiniTensor

fix: infershape should change some op private params

parent 5375d529db
commit c38d26bee8
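Context for the hunks below: inferShape changes from a const member function to a non-const one across the operator classes, because the fix has shape inference refresh each operator's private parameters (b, m, k, n, c, h, w, f, r, s, ...) from the current input dims, and a const member function cannot assign to data members. A minimal sketch of that constraint, using a hypothetical Op class rather than the repository's own types:

    #include <vector>

    class Op {
        int n = 0; // private param cached from the input shape
      public:
        // Would not compile: a const member function cannot assign to n.
        //   std::vector<int> inferShape(const std::vector<int> &in) const {
        //       n = in[0]; // error: assignment of member in read-only object
        //       return {n};
        //   }
        std::vector<int> inferShape(const std::vector<int> &in) {
            n = in[0]; // fine once const is dropped, as in this commit
            return {n};
        }
    };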
@@ -27,7 +27,7 @@ class AllGatherObj : public OperatorObj {
     int numInputs() const override { return 1; }
     int numOutputs() const override { return world_size; }
 
-    optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
+    optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
 
     std::string toString() const override;
@@ -33,7 +33,7 @@ class AllReduceBaseObj : public OperatorObj {
     int numInputs() const override { return 1; }
     int numOutputs() const override { return 1; }
 
-    optional<vector<Shape>> inferShape(const TensorVec &inputs) const override {
+    optional<vector<Shape>> inferShape(const TensorVec &inputs) override {
         return {{inputs[0]->getDims()}};
     };
@@ -26,7 +26,7 @@ class BroadcastObj : public OperatorObj {
     int numInputs() const override { return 1; }
     int numOutputs() const override { return 1; }
 
-    optional<vector<Shape>> inferShape(const TensorVec &inputs) const override {
+    optional<vector<Shape>> inferShape(const TensorVec &inputs) override {
         return {{inputs[0]->getDims()}};
     };
@@ -108,6 +108,7 @@ class ConvBaseObj : public OperatorObj {
     int getPw() const { return pw; }
     int getSh() const { return sh; }
     int getSw() const { return sw; }
+    void setNCHWFRS(Tensor input, Tensor weight);
     auto getNCHWFRS() const { return tuple(n, c, h, w, f, r, s); }
     auto getPadStrideDilation() const { return tuple(ph, pw, sh, sw, dh, dw); }
     int getChannelPerGroup() const {
@@ -623,13 +623,9 @@ class OnnxStub:
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                     next(
-                        (
-                            attr.i
-                            for attr in node.attribute
-                            if attr.name == "root"
-                        ),
-                        0,
-                    ),
+                        (attr.i for attr in node.attribute if attr.name == "root"),
+                        0,
+                    ),
                 )
             elif node.op_type == "Expand":
                 shape = _parse_data(data[node.input[1]])
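The OnnxStub hunk above is formatting only: next((attr.i for attr in node.attribute if attr.name == "root"), 0) still selects the node's root attribute and falls back to 0 when no such attribute is present; only the line wrapping of the generator expression changes.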
@@ -329,7 +329,7 @@ class TestStringMethods(unittest.TestCase):
                 [pads_data],
             )
         )
 
     def test_allReduceSum(self):
         input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
         output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 2, 4])
@@ -349,7 +349,7 @@ class TestStringMethods(unittest.TestCase):
         graph = make_graph([allReduceProd], "allReduceProd", [input], [output])
         model = make_model(graph)
         from_onnx(model, backend.cpu_runtime())
 
     def test_allReduceMin(self):
         input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
         output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 2, 4])
@@ -379,14 +379,12 @@ class TestStringMethods(unittest.TestCase):
         graph = make_graph([allReduceAvg], "allReduceAvg", [input], [output])
         model = make_model(graph)
         from_onnx(model, backend.cpu_runtime())
 
     def test_split(self):
         input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
-        split = make_node(
-            "Split", ["input"], ["output"], name="split", axis=0
-        )
+        split = make_node("Split", ["input"], ["output"], name="split", axis=0)
         make_and_import_model(make_graph([split], "split", [input], []))
 
     def test_allBroadcast(self):
         input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
         output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 2, 4])
@@ -508,6 +506,7 @@ class TestStringMethods(unittest.TestCase):
         array1 = np.array(tensor1, copy=False)
         self.assertTrue(np.array_equal(array1, np_array))
 
+class TestDynamicTensor(unittest.TestCase):
     def test_dynamic_tensor(self):
         filename = r"resnet18-v2-7.onnx"
@@ -1,5 +1,4 @@
 #include "core/graph.h"
-#include "operators/reshape.h"
 #include <algorithm>
 #include <numeric>
 #include <queue>
@@ -19,6 +19,7 @@
 #include "operators/transpose.h"
 #include "operators/unary.h"
 #include "operators/where.h"
+#include <numeric>
 
 namespace infini {
@@ -507,6 +508,9 @@ void GraphHandlerObj::change_shape(const vector<int> &shape, int tensorId) {
     IT_ASSERT(tensor != nullptr);
     IT_ASSERT(shape.size() != 0);
     tensor->setShape(shape);
+    size_t size = std::accumulate(shape.begin(), shape.end(), 1,
+                                  [](auto acc, auto x) { return acc * x; });
+    tensor->setSize(size);
 }
 
 } // namespace infini
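The three added lines compute the tensor's new element count as the product of its dimensions before storing it with setSize. A standalone sketch of the same accumulate pattern (hypothetical values; the seed is written as size_t{1} here to keep the product in size_t):

    #include <numeric>
    #include <vector>

    int main() {
        std::vector<int> shape{1, 3, 2, 4};
        // Product of dims: 1 * 3 * 2 * 4 = 24 elements.
        size_t size = std::accumulate(shape.begin(), shape.end(), size_t{1},
                                      [](auto acc, auto x) { return acc * x; });
        return size == 24 ? 0 : 1;
    }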
@@ -22,13 +22,16 @@ string G2BMMObj::toString() const {
 
 optional<vector<Shape>> G2BMMObj::inferShape(const TensorVec &inputs) {
     auto A = inputs[0], B = inputs[1];
+    b = A->getDims()[0];
+    m = A->getDims()[1];
+    k = A->getDims()[2];
 
     IT_ASSERT(A->getRank() == 3 && B->getRank() == 3);
     IT_ASSERT(A->getDims()[0] == B->getDims()[0]);
     IT_ASSERT(A->getDims()[1] == B->getDims()[1]);
     IT_ASSERT(A->getDims()[2] == B->getDims()[2]);
     IT_ASSERT(width >= 0);
-    int b(A->getDims()[0]), m(A->getDims()[1]), n(2 * width + 1);
+    int n(2 * width + 1);
     return {{{b, m, n}}};
 }
@@ -23,13 +23,16 @@ string GBMMObj::toString() const {
 
 optional<vector<Shape>> GBMMObj::inferShape(const TensorVec &inputs) {
     auto A = inputs[0], B = inputs[1];
+    b = A->getDims()[0];
+    m = A->getDims()[1];
+    w = (A->getDims()[2] - 1) / 2;
+    n = B->getDims()[2];
 
     IT_ASSERT(A->getRank() == 3 && B->getRank() == 3);
     IT_ASSERT(A->getDims()[0] == B->getDims()[0]);
     IT_ASSERT(A->getDims()[1] == B->getDims()[1]);
     IT_ASSERT(A->getDims()[2] % 2 != 0);
-    int b(A->getDims()[0]), m(A->getDims()[1]), k(B->getDims()[2]);
-    return {{{b, m, k}}};
+    return {{{b, m, n}}};
 }
 
 vector<int> GBMMObj::getWorkloadVector() const {
@@ -10,8 +10,7 @@ AllGatherObj::AllGatherObj(GraphObj *graph, Tensor input,
     IT_ASSERT(checkValid(graph));
 }
 
-optional<vector<Shape>>
-AllGatherObj::inferShape(const TensorVec &inputs) const {
+optional<vector<Shape>> AllGatherObj::inferShape(const TensorVec &inputs) {
     Shape input_shape = inputs[0]->getDims();
     vector<Shape> output_shapes(getWorldSize(), input_shape);
     return output_shapes;
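The rewritten AllGatherObj::inferShape fans one input out to world_size outputs, each with the unchanged input shape. A minimal standalone sketch of that rule (hypothetical free function, not the repository's API):

    #include <vector>
    using Shape = std::vector<int>;

    // One output shape per rank, each a copy of the input shape.
    std::vector<Shape> allGatherShapes(const Shape &in, int worldSize) {
        return std::vector<Shape>(worldSize, in);
    }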
@@ -35,6 +35,16 @@ string ConvBaseObj::toString() const {
     return os.str();
 }
 
+void ConvBaseObj::setNCHWFRS(Tensor input, Tensor weight) {
+    n = input->getDims()[0];
+    c = input->getDims()[1];
+    h = input->getDims()[2];
+    w = input->getDims()[3];
+    f = weight->getDims()[0];
+    r = weight->getDims()[2];
+    s = weight->getDims()[3];
+}
+
 vector<int> ConvBaseObj::getWorkloadVector() const {
     return {type.underlying(), n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw};
 }
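Design note: with setNCHWFRS defined once on ConvBaseObj, each conv-variant inferShape below swaps its local auto copies of the dimensions for a call to this setter, so the cached members that getNCHWFRS() and getWorkloadVector() report stay in sync with the most recent inputs.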
@@ -84,12 +94,7 @@ ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
 
 optional<vector<Shape>> ConvObj::inferShape(const TensorVec &inputs) {
     const auto &input = inputs[0], &weight = inputs[1];
-    auto n = input->getDims()[0];
-    auto h = input->getDims()[2];
-    auto w = input->getDims()[3];
-    auto f = weight->getDims()[0];
-    auto r = weight->getDims()[2];
-    auto s = weight->getDims()[3];
+    setNCHWFRS(input, weight);
     int on = n, oc = f;
     int oh = 0, ow = 0;
     // For NCHW+FCRS layout, C of input is divisable by C of weight
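The hunk ends before the oh/ow computation, so for orientation only, the textbook output-size formula that a convolution shape inference of this form implements is sketched below (the standard formula, not a line quoted from the repository):

    // h: input size, r: kernel size, ph: padding, sh: stride, dh: dilation.
    int convOutSize(int h, int r, int ph, int sh, int dh) {
        return (h + 2 * ph - dh * (r - 1) - 1) / sh + 1;
    }

For h = 224, r = 3, ph = 1, sh = 1, dh = 1 this yields (224 + 2 - 2 - 1) / 1 + 1 = 224.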
@@ -143,13 +148,13 @@ ConvTransposed2dObj::ConvTransposed2dObj(GraphObj *graph, Tensor input,
 optional<vector<Shape>>
 ConvTransposed2dObj::inferShape(const TensorVec &inputs) {
     const Tensor &input = inputs[0], &weight = inputs[1];
-    auto n = input->getDims()[0];
-    auto f = input->getDims()[1];
-    auto h = input->getDims()[2];
-    auto w = input->getDims()[3];
-    auto c = weight->getDims()[1];
-    auto r = weight->getDims()[2];
-    auto s = weight->getDims()[3];
+    n = input->getDims()[0];
+    f = input->getDims()[1];
+    h = input->getDims()[2];
+    w = input->getDims()[3];
+    c = weight->getDims()[1];
+    r = weight->getDims()[2];
+    s = weight->getDims()[3];
     IT_ASSERT(f == weight->getDims()[0]);
 
     int on = n, oc = c * group;
@@ -221,12 +226,7 @@ ConvBackwardFilterObj::ConvBackwardFilterObj(GraphObj *graph, Tensor inputX,
 optional<vector<Shape>>
 ConvBackwardFilterObj::inferShape(const TensorVec &inputs) {
     const auto &inputX = inputs[0], &diffY = inputs[1];
-    auto n = inputX->getDims()[0];
-    auto h = inputX->getDims()[2];
-    auto w = inputX->getDims()[3];
-    auto f = diffY->getDims()[0];
-    auto r = diffY->getDims()[2];
-    auto s = diffY->getDims()[3];
+    setNCHWFRS(inputX, diffY);
     int on = n, oc = f;
     int oh = 0, ow = 0;
     // For NCHW+FCRS layout, C of input is divisable by C of weight
@@ -282,15 +282,14 @@ ConvTransposed2dNHWCObj::ConvTransposed2dNHWCObj(GraphObj *graph, Tensor input,
 optional<vector<Shape>>
 ConvTransposed2dNHWCObj::inferShape(const TensorVec &inputs) {
     const Tensor &input = inputs[0], &weight = inputs[1];
-    auto n = input->getDims()[0];
-    auto f = input->getDims()[3];
-    auto h = input->getDims()[1];
-    auto w = input->getDims()[2];
-    auto c = weight->getDims()[3];
-    auto r = weight->getDims()[1];
-    auto s = weight->getDims()[2];
-    if (f != weight->getDims()[0])
-        return {};
+    n = input->getDims()[0];
+    f = input->getDims()[3];
+    h = input->getDims()[1];
+    w = input->getDims()[2];
+    c = weight->getDims()[3];
+    r = weight->getDims()[1];
+    s = weight->getDims()[2];
+    IT_ASSERT(f == weight->getDims()[0]);
 
     int on = n, oc = c * group;
     int oh = 0, ow = 0;
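Note the behavioral change in the NHWC path: the old code returned an empty optional when f disagreed with weight->getDims()[0], while the new code asserts, so a channel mismatch now fails loudly during shape inference instead of yielding an empty result.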
@@ -16,8 +16,10 @@ PoolingObj::PoolingObj(GraphObj *graph, OpType optype, Tensor input,
 
 optional<vector<Shape>> PoolingObj::inferShape(const TensorVec &inputs) {
     const auto &input = inputs[0];
-    auto h = input->getDims()[input->getRank() - 2],
-         w = input->getDims()[input->getRank() - 1];
+    n = input->getDims()[0];
+    c = input->getDims()[1];
+    h = input->getDims()[input->getRank() - 2];
+    w = input->getDims()[input->getRank() - 1];
     int oh = (h - (kh - sh) + ph * 2) / sh;
     int ow = (w - (kw - sw) + pw * 2) / sw;
     auto ret = input->getDims();
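A quick numeric check of the pooling arithmetic above, under assumed sizes (224x224 input, 3x3 kernel, stride 2, padding 1; not values from the repository):

    #include <cassert>

    int main() {
        int h = 224, kh = 3, sh = 2, ph = 1;
        int oh = (h - (kh - sh) + ph * 2) / sh; // (224 - 1 + 2) / 2 = 112
        assert(oh == 112);
        return 0;
    }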