fix: inferShape should update some ops' private params

zhangyunze 2023-09-05 11:30:22 +08:00
parent 5375d529db
commit c38d26bee8
13 changed files with 58 additions and 53 deletions
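The thread running through every hunk below: inferShape drops its const qualifier so that an operator can refresh its cached private dimensions (n, c, h, w, f, r, s, ...) while re-deriving output shapes, which is what makes runtime shape changes (see change_shape below) propagate correctly. A minimal standalone sketch of the pattern, using a simplified stand-in for the repo's OperatorObj hierarchy (types and names here are illustrative only):

    #include <vector>
    using Shape = std::vector<int>;

    struct Op {
        int n = 0;  // cached batch dim, later reused by e.g. getWorkloadVector()
        // With `const` after the parameter list, the assignment to `n` below
        // would not compile; dropping it lets inference update the cache.
        Shape inferShape(const std::vector<Shape> &inputs) /* no const */ {
            n = inputs[0][0];   // refresh cache from the possibly-resized input
            return inputs[0];   // identity shape, as in AllReduce/Broadcast
        }
    };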


@@ -27,7 +27,7 @@ class AllGatherObj : public OperatorObj {
     int numInputs() const override { return 1; }
     int numOutputs() const override { return world_size; }
-    optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
+    optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
     std::string toString() const override;


@@ -33,7 +33,7 @@ class AllReduceBaseObj : public OperatorObj {
     int numInputs() const override { return 1; }
     int numOutputs() const override { return 1; }
-    optional<vector<Shape>> inferShape(const TensorVec &inputs) const override {
+    optional<vector<Shape>> inferShape(const TensorVec &inputs) override {
         return {{inputs[0]->getDims()}};
     };


@@ -26,7 +26,7 @@ class BroadcastObj : public OperatorObj {
     int numInputs() const override { return 1; }
     int numOutputs() const override { return 1; }
-    optional<vector<Shape>> inferShape(const TensorVec &inputs) const override {
+    optional<vector<Shape>> inferShape(const TensorVec &inputs) override {
         return {{inputs[0]->getDims()}};
     };


@@ -108,6 +108,7 @@ class ConvBaseObj : public OperatorObj {
     int getPw() const { return pw; }
     int getSh() const { return sh; }
     int getSw() const { return sw; }
+    void setNCHWFRS(Tensor input, Tensor weight);
     auto getNCHWFRS() const { return tuple(n, c, h, w, f, r, s); }
     auto getPadStrideDilation() const { return tuple(ph, pw, sh, sw, dh, dw); }
     int getChannelPerGroup() const {


@@ -623,13 +623,9 @@ class OnnxStub:
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                     next(
-                        (
-                            attr.i
-                            for attr in node.attribute
-                            if attr.name == "root"
-                        ),
-                        0,
-                    ),
+                        (attr.i for attr in node.attribute if attr.name == "root"),
+                        0,
+                    ),
                 )
             elif node.op_type == "Expand":
                 shape = _parse_data(data[node.input[1]])


@@ -329,7 +329,7 @@ class TestStringMethods(unittest.TestCase):
                 [pads_data],
             )
         )

     def test_allReduceSum(self):
         input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
         output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 2, 4])

@@ -349,7 +349,7 @@ class TestStringMethods(unittest.TestCase):
         graph = make_graph([allReduceProd], "allReduceProd", [input], [output])
         model = make_model(graph)
         from_onnx(model, backend.cpu_runtime())

     def test_allReduceMin(self):
         input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
         output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 2, 4])

@@ -379,14 +379,12 @@ class TestStringMethods(unittest.TestCase):
         graph = make_graph([allReduceAvg], "allReduceAvg", [input], [output])
         model = make_model(graph)
         from_onnx(model, backend.cpu_runtime())

     def test_split(self):
         input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
-        split = make_node(
-            "Split", ["input"], ["output"], name="split", axis=0
-        )
+        split = make_node("Split", ["input"], ["output"], name="split", axis=0)
         make_and_import_model(make_graph([split], "split", [input], []))

     def test_allBroadcast(self):
         input = make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 2, 4])
         output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 2, 4])

@@ -508,6 +506,7 @@ class TestStringMethods(unittest.TestCase):
         array1 = np.array(tensor1, copy=False)
         self.assertTrue(np.array_equal(array1, np_array))

+
 class TestDynamicTensor(unittest.TestCase):
     def test_dynamic_tensor(self):
         filename = r"resnet18-v2-7.onnx"


@@ -1,5 +1,4 @@
 #include "core/graph.h"
-#include "operators/reshape.h"
 #include <algorithm>
 #include <numeric>
 #include <queue>


@@ -19,6 +19,7 @@
 #include "operators/transpose.h"
 #include "operators/unary.h"
 #include "operators/where.h"
+#include <numeric>

 namespace infini {

@@ -507,6 +508,9 @@ void GraphHandlerObj::change_shape(const vector<int> &shape, int tensorId) {
     IT_ASSERT(tensor != nullptr);
     IT_ASSERT(shape.size() != 0);
     tensor->setShape(shape);
+    size_t size = std::accumulate(shape.begin(), shape.end(), 1,
+                                  [](auto acc, auto x) { return acc * x; });
+    tensor->setSize(size);
 }

 } // namespace infini
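change_shape now recomputes the tensor's element count as a left fold over the dims. A standalone sketch of the same product (note: the commit seeds the fold with the int literal 1, so intermediate products are computed in int; the size_t seed plus std::multiplies below is an equivalent spelling that avoids int overflow on huge shapes, an observation, not something the commit does):

    #include <cstddef>
    #include <functional>
    #include <numeric>
    #include <vector>

    int main() {
        std::vector<int> shape{1, 3, 2, 4};
        // Same fold as change_shape: 1 * 1 * 3 * 2 * 4 == 24 elements.
        std::size_t size = std::accumulate(shape.begin(), shape.end(),
                                           std::size_t{1},
                                           std::multiplies<std::size_t>());
        return size == 24 ? 0 : 1;
    }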


@@ -22,13 +22,16 @@ string G2BMMObj::toString() const {
 optional<vector<Shape>> G2BMMObj::inferShape(const TensorVec &inputs) {
     auto A = inputs[0], B = inputs[1];
+    b = A->getDims()[0];
+    m = A->getDims()[1];
+    k = A->getDims()[2];
     IT_ASSERT(A->getRank() == 3 && B->getRank() == 3);
     IT_ASSERT(A->getDims()[0] == B->getDims()[0]);
     IT_ASSERT(A->getDims()[1] == B->getDims()[1]);
     IT_ASSERT(A->getDims()[2] == B->getDims()[2]);
     IT_ASSERT(width >= 0);
-    int b(A->getDims()[0]), m(A->getDims()[1]), n(2 * width + 1);
+    int n(2 * width + 1);
     return {{{b, m, n}}};
 }


@@ -23,13 +23,16 @@ string GBMMObj::toString() const {
 optional<vector<Shape>> GBMMObj::inferShape(const TensorVec &inputs) {
     auto A = inputs[0], B = inputs[1];
+    b = A->getDims()[0];
+    m = A->getDims()[1];
+    w = (A->getDims()[2] - 1) / 2;
+    n = B->getDims()[2];
     IT_ASSERT(A->getRank() == 3 && B->getRank() == 3);
     IT_ASSERT(A->getDims()[0] == B->getDims()[0]);
     IT_ASSERT(A->getDims()[1] == B->getDims()[1]);
     IT_ASSERT(A->getDims()[2] % 2 != 0);
-    int b(A->getDims()[0]), m(A->getDims()[1]), k(B->getDims()[2]);
-    return {{{b, m, k}}};
+    return {{{b, m, n}}};
 }

 vector<int> GBMMObj::getWorkloadVector() const {
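The two hunks above are inverses of each other: G2BMM emits a band of n = 2 * width + 1 diagonals, and GBMM recovers width as (dims[2] - 1) / 2, which is exactly why it asserts that dimension is odd. A quick standalone check of the round trip (illustrative numbers):

    #include <cassert>

    int main() {
        int width = 2;
        int band = 2 * width + 1;         // G2BMM output dim: 5 diagonals
        assert(band % 2 != 0);            // GBMM's oddness precondition
        assert((band - 1) / 2 == width);  // GBMM recovers the original width
        return 0;
    }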


@@ -10,8 +10,7 @@ AllGatherObj::AllGatherObj(GraphObj *graph, Tensor input,
     IT_ASSERT(checkValid(graph));
 }

-optional<vector<Shape>>
-AllGatherObj::inferShape(const TensorVec &inputs) const {
+optional<vector<Shape>> AllGatherObj::inferShape(const TensorVec &inputs) {
     Shape input_shape = inputs[0]->getDims();
     vector<Shape> output_shapes(getWorldSize(), input_shape);
     return output_shapes;
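AllGather leaves each rank's shape untouched and simply reports world_size copies of it, via vector's fill constructor. A standalone sketch with an assumed communicator of four ranks:

    #include <cassert>
    #include <vector>
    using Shape = std::vector<int>;

    int main() {
        Shape input{1, 3, 2, 4};
        int world_size = 4;                             // assumed communicator size
        std::vector<Shape> outputs(world_size, input);  // same fill-construction as above
        assert(outputs.size() == 4 && outputs[3] == input);
        return 0;
    }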


@@ -35,6 +35,16 @@ string ConvBaseObj::toString() const {
     return os.str();
 }

+void ConvBaseObj::setNCHWFRS(Tensor input, Tensor weight) {
+    n = input->getDims()[0];
+    c = input->getDims()[1];
+    h = input->getDims()[2];
+    w = input->getDims()[3];
+    f = weight->getDims()[0];
+    r = weight->getDims()[2];
+    s = weight->getDims()[3];
+}
+
 vector<int> ConvBaseObj::getWorkloadVector() const {
     return {type.underlying(), n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw};
 }

@@ -84,12 +94,7 @@ ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
 optional<vector<Shape>> ConvObj::inferShape(const TensorVec &inputs) {
     const auto &input = inputs[0], &weight = inputs[1];
-    auto n = input->getDims()[0];
-    auto h = input->getDims()[2];
-    auto w = input->getDims()[3];
-    auto f = weight->getDims()[0];
-    auto r = weight->getDims()[2];
-    auto s = weight->getDims()[3];
+    setNCHWFRS(input, weight);
     int on = n, oc = f;
     int oh = 0, ow = 0;
     // For NCHW+FCRS layout, C of input is divisable by C of weight

@@ -143,13 +148,13 @@ ConvTransposed2dObj::ConvTransposed2dObj(GraphObj *graph, Tensor input,
 optional<vector<Shape>>
 ConvTransposed2dObj::inferShape(const TensorVec &inputs) {
     const Tensor &input = inputs[0], &weight = inputs[1];
-    auto n = input->getDims()[0];
-    auto f = input->getDims()[1];
-    auto h = input->getDims()[2];
-    auto w = input->getDims()[3];
-    auto c = weight->getDims()[1];
-    auto r = weight->getDims()[2];
-    auto s = weight->getDims()[3];
+    n = input->getDims()[0];
+    f = input->getDims()[1];
+    h = input->getDims()[2];
+    w = input->getDims()[3];
+    c = weight->getDims()[1];
+    r = weight->getDims()[2];
+    s = weight->getDims()[3];
     IT_ASSERT(f == weight->getDims()[0]);
     int on = n, oc = c * group;

@@ -221,12 +226,7 @@ ConvBackwardFilterObj::ConvBackwardFilterObj(GraphObj *graph, Tensor inputX,
 optional<vector<Shape>>
 ConvBackwardFilterObj::inferShape(const TensorVec &inputs) {
     const auto &inputX = inputs[0], &diffY = inputs[1];
-    auto n = inputX->getDims()[0];
-    auto h = inputX->getDims()[2];
-    auto w = inputX->getDims()[3];
-    auto f = diffY->getDims()[0];
-    auto r = diffY->getDims()[2];
-    auto s = diffY->getDims()[3];
+    setNCHWFRS(inputX, diffY);
     int on = n, oc = f;
     int oh = 0, ow = 0;
     // For NCHW+FCRS layout, C of input is divisable by C of weight

@@ -282,15 +282,14 @@ ConvTransposed2dNHWCObj::ConvTransposed2dNHWCObj(GraphObj *graph, Tensor input,
 optional<vector<Shape>>
 ConvTransposed2dNHWCObj::inferShape(const TensorVec &inputs) {
     const Tensor &input = inputs[0], &weight = inputs[1];
-    auto n = input->getDims()[0];
-    auto f = input->getDims()[3];
-    auto h = input->getDims()[1];
-    auto w = input->getDims()[2];
-    auto c = weight->getDims()[3];
-    auto r = weight->getDims()[1];
-    auto s = weight->getDims()[2];
-    if (f != weight->getDims()[0])
-        return {};
+    n = input->getDims()[0];
+    f = input->getDims()[3];
+    h = input->getDims()[1];
+    w = input->getDims()[2];
+    c = weight->getDims()[3];
+    r = weight->getDims()[1];
+    s = weight->getDims()[2];
+    IT_ASSERT(f == weight->getDims()[0]);
     int on = n, oc = c * group;
     int oh = 0, ow = 0;
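One detail worth noting in the hunk above: ConvTransposed2dNHWCObj keeps its assignments inline rather than calling the new setNCHWFRS, because that helper indexes dims in NCHW order while this op's activations are NHWC, so the same logical dimensions sit at different indices. A standalone illustration with assumed dims (n=1, h=8, w=9, f=16):

    #include <cassert>
    #include <vector>

    int main() {
        std::vector<int> dims{1, 8, 9, 16};  // assumed NHWC activation
        // NHWC reads, as in ConvTransposed2dNHWCObj::inferShape:
        int h = dims[1], w = dims[2], f = dims[3];
        // NCHW reads, as in setNCHWFRS, land on the wrong slots here:
        int c_bad = dims[1], h_bad = dims[2], w_bad = dims[3];
        assert(h == 8 && w == 9 && f == 16);
        assert(c_bad == 8 && h_bad == 9 && w_bad == 16);  // not the NHWC meaning
        return 0;
    }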


@@ -16,8 +16,10 @@ PoolingObj::PoolingObj(GraphObj *graph, OpType optype, Tensor input,
 optional<vector<Shape>> PoolingObj::inferShape(const TensorVec &inputs) {
     const auto &input = inputs[0];
-    auto h = input->getDims()[input->getRank() - 2],
-         w = input->getDims()[input->getRank() - 1];
+    n = input->getDims()[0];
+    c = input->getDims()[1];
+    h = input->getDims()[input->getRank() - 2];
+    w = input->getDims()[input->getRank() - 1];
     int oh = (h - (kh - sh) + ph * 2) / sh;
     int ow = (w - (kw - sw) + pw * 2) / sw;
     auto ret = input->getDims();
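The output-size expression above is the textbook pooling formula in disguise: under integer division with non-negative operands, (h - (kh - sh) + 2*ph) / sh equals (h + 2*ph - kh) / sh + 1. A standalone check with assumed sample numbers:

    #include <cassert>

    int main() {
        int h = 32, kh = 3, sh = 2, ph = 1;
        int oh = (h - (kh - sh) + ph * 2) / sh;    // form used in PoolingObj::inferShape
        assert(oh == (h + 2 * ph - kh) / sh + 1);  // textbook equivalent
        assert(oh == 16);
        return 0;
    }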