forked from jiuyuan/InfiniTensor
Add: connection among tensors and operators (#45)
* Add: refs_to_wrefs and wrefs_to_refs * Add: op and tensor connection * Add: inception-v3 block test * Refactor: addOperatorAndConnect Co-authored-by: Liyan Zheng <liyan-zheng@outlook.com>
This commit is contained in:
parent
d1c913010f
commit
4e0040c8a0
|
@ -30,7 +30,7 @@ class GraphObj : public Object {
|
||||||
*/
|
*/
|
||||||
template <typename T, typename... Args> Ref<T> addOp(Args &&...args) {
|
template <typename T, typename... Args> Ref<T> addOp(Args &&...args) {
|
||||||
Ref<T> op = infini::make_ref<T>(this, std::forward<Args>(args)...);
|
Ref<T> op = infini::make_ref<T>(this, std::forward<Args>(args)...);
|
||||||
ops.push_back(op);
|
addOperatorAndConnect(op);
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -40,7 +40,7 @@ class GraphObj : public Object {
|
||||||
template <typename T, typename... Args>
|
template <typename T, typename... Args>
|
||||||
Ref<T> addOpWithOutputs(Args &&...args) {
|
Ref<T> addOpWithOutputs(Args &&...args) {
|
||||||
Ref<T> op = infini::make_ref<T>(nullptr, std::forward<Args>(args)...);
|
Ref<T> op = infini::make_ref<T>(nullptr, std::forward<Args>(args)...);
|
||||||
ops.push_back(op);
|
addOperatorAndConnect(op);
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -55,11 +55,10 @@ class GraphObj : public Object {
|
||||||
void dataMalloc();
|
void dataMalloc();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// TODO: updateConnection
|
|
||||||
/**
|
/**
|
||||||
* @brief Add reverse connections and Op relationship in ctor.
|
* @brief Add reverse connections and Op relationship in ctor.
|
||||||
*/
|
*/
|
||||||
void updateConnection();
|
void addOperatorAndConnect(const Operator &op);
|
||||||
|
|
||||||
// TODO: move to another class
|
// TODO: move to another class
|
||||||
// bool exportOnnx(const char *path);
|
// bool exportOnnx(const char *path);
|
||||||
|
|
|
@ -35,7 +35,7 @@ class Object {
|
||||||
virtual ~Object(){};
|
virtual ~Object(){};
|
||||||
virtual string toString() const = 0;
|
virtual string toString() const = 0;
|
||||||
void print() { std::cout << toString() << std::endl; }
|
void print() { std::cout << toString() << std::endl; }
|
||||||
Guid getGuid() const { return guid; }
|
GuidBaseType getGuid() const { return guid; }
|
||||||
};
|
};
|
||||||
|
|
||||||
inline std::ostream &operator<<(std::ostream &os, const Object &obj) {
|
inline std::ostream &operator<<(std::ostream &os, const Object &obj) {
|
||||||
|
|
|
@ -142,12 +142,11 @@ class OperatorObj : public Object {
|
||||||
OpType type;
|
OpType type;
|
||||||
TensorVec inputs;
|
TensorVec inputs;
|
||||||
TensorVec outputs;
|
TensorVec outputs;
|
||||||
// vector<WRef<Operator>> predecessors;
|
vector<WRef<OperatorObj>> predecessors;
|
||||||
// vector<WRef<Operator>> successors;
|
vector<WRef<OperatorObj>> successors;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
OperatorObj(OpType opType, TensorVec inputs, TensorVec outputs)
|
OperatorObj(OpType opType, TensorVec inputs, TensorVec outputs);
|
||||||
: type(opType), inputs(inputs), outputs(outputs) {}
|
|
||||||
virtual optional<vector<Shape>>
|
virtual optional<vector<Shape>>
|
||||||
inferShape(const TensorVec &inputs) const = 0;
|
inferShape(const TensorVec &inputs) const = 0;
|
||||||
virtual vector<DataType> inferDataType(const TensorVec &inputs) const;
|
virtual vector<DataType> inferDataType(const TensorVec &inputs) const;
|
||||||
|
@ -177,9 +176,7 @@ class OperatorObj : public Object {
|
||||||
bool isMemBoundOp() const;
|
bool isMemBoundOp() const;
|
||||||
|
|
||||||
public: // getter and setter
|
public: // getter and setter
|
||||||
// TensorVec getInputs() { return inputs; }
|
|
||||||
const TensorVec &getInputs() const { return inputs; }
|
const TensorVec &getInputs() const { return inputs; }
|
||||||
// TensorVec getOutputs() { return outputs; }
|
|
||||||
const TensorVec &getOutputs() const { return outputs; }
|
const TensorVec &getOutputs() const { return outputs; }
|
||||||
Tensor getInputs(size_t i) const { return inputs.at(i); }
|
Tensor getInputs(size_t i) const { return inputs.at(i); }
|
||||||
Tensor getOutput() const {
|
Tensor getOutput() const {
|
||||||
|
@ -190,6 +187,10 @@ class OperatorObj : public Object {
|
||||||
IT_ASSERT(i < outputs.size(), "Index exceeded");
|
IT_ASSERT(i < outputs.size(), "Index exceeded");
|
||||||
return outputs.at(i);
|
return outputs.at(i);
|
||||||
}
|
}
|
||||||
|
void addPredecessors(const Operator &op) { predecessors.emplace_back(op); }
|
||||||
|
void addSuccessors(const Operator &op) { successors.emplace_back(op); }
|
||||||
|
OpVec getPredecessors() const { return wrefs_to_refs(predecessors); }
|
||||||
|
OpVec getSuccessors() const { return wrefs_to_refs(successors); }
|
||||||
OpType getOpType() const { return type; }
|
OpType getOpType() const { return type; }
|
||||||
// HACK: set correct data type
|
// HACK: set correct data type
|
||||||
DataType getDType() const { return getInputs(0)->getDType(); }
|
DataType getDType() const { return getInputs(0)->getDType(); }
|
||||||
|
|
|
@ -25,12 +25,19 @@ Ref<T> as(const Ref<U> &ref) {
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
std::vector<WRef<T>> get_wref_vec(const std::vector<Ref<T>> &vec) {
|
std::vector<WRef<T>> refs_to_wrefs(const std::vector<Ref<T>> &refs) {
|
||||||
std::vector<WRef<T>> wref_vec;
|
std::vector<WRef<T>> wrefs;
|
||||||
wref_vec.reserve(vec.size());
|
for (const auto &ref : refs)
|
||||||
for (const auto &ref : vec)
|
wrefs.emplace_back(ref);
|
||||||
wref_vec.emplace_back(ref);
|
return wrefs;
|
||||||
return wref_vec;
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
std::vector<Ref<T>> wrefs_to_refs(const std::vector<WRef<T>> &wrefs) {
|
||||||
|
std::vector<Ref<T>> refs;
|
||||||
|
for (const auto &wref : wrefs)
|
||||||
|
refs.emplace_back(wref);
|
||||||
|
return refs;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace infini
|
} // namespace infini
|
|
@ -19,8 +19,8 @@ class TensorBaseObj : public Object {
|
||||||
int dim;
|
int dim;
|
||||||
|
|
||||||
DataType dtype;
|
DataType dtype;
|
||||||
vector<WRef<TensorBaseObj>> inputOf;
|
vector<WRef<OperatorObj>> inputOf;
|
||||||
WRef<TensorBaseObj> outputOf;
|
WRef<OperatorObj> outputOf;
|
||||||
Blob data;
|
Blob data;
|
||||||
Runtime runtime;
|
Runtime runtime;
|
||||||
|
|
||||||
|
@ -44,41 +44,12 @@ class TensorBaseObj : public Object {
|
||||||
DataType getDType() const { return dtype; }
|
DataType getDType() const { return dtype; }
|
||||||
Runtime getRuntime() const { return runtime; }
|
Runtime getRuntime() const { return runtime; }
|
||||||
|
|
||||||
// uint64_t getHash() const { return hash; }
|
void addInputOf(const Operator &op) { inputOf.emplace_back(op); }
|
||||||
|
void setOutputOf(const Operator &op) { outputOf = op; }
|
||||||
// void setInputOf(const OpVec &ops) {
|
OpVec getInputOf() { return wrefs_to_refs(inputOf); }
|
||||||
// inputOf.clear();
|
Operator getOutputOf() { return outputOf.lock(); }
|
||||||
// for (const auto &op : ops)
|
|
||||||
// inputOf.emplace_back(op);
|
|
||||||
// }
|
|
||||||
// void addInputOf(Operator op) { inputOf.emplace_back(op); }
|
|
||||||
// void setOutputOf(Operator op) { outputOf = op; }
|
|
||||||
|
|
||||||
// const OpVec &getInputOf() { return inputOf; }
|
|
||||||
// Operator *getOutputOf() { return outputOf; }
|
|
||||||
// std::pair<Operator *, int> getOutputOfWithIndex();
|
// std::pair<Operator *, int> getOutputOfWithIndex();
|
||||||
|
|
||||||
// const Dim &getDims() const { return dims; }
|
|
||||||
// void setDims(const Dim &dms) { dims = dms; }
|
|
||||||
|
|
||||||
// bool dataRand(int seed = 0) {
|
|
||||||
// if (data == nullptr)
|
|
||||||
// data = new VType[size()];
|
|
||||||
// if (!random_inited)
|
|
||||||
// initFastrand();
|
|
||||||
// // srand(seed);
|
|
||||||
// // faster rand generator; parallel
|
|
||||||
// size_t iEnd = size();
|
|
||||||
// // std::cerr << "Init beginned " << std::endl;
|
|
||||||
// #pragma omp parallel for
|
|
||||||
// for (size_t i = 0; i < iEnd; ++i)
|
|
||||||
// data[i] = fastrand(random_seed[omp_get_thread_num() * 16]) %
|
|
||||||
// 10000;
|
|
||||||
// // std::cerr << "Init finished" << std::endl;
|
|
||||||
// computed = ComputedFull;
|
|
||||||
// return true;
|
|
||||||
// }
|
|
||||||
|
|
||||||
// bool setScalar(VType val) {
|
// bool setScalar(VType val) {
|
||||||
// if (data == nullptr || !dims.empty())
|
// if (data == nullptr || !dims.empty())
|
||||||
// return false;
|
// return false;
|
||||||
|
@ -102,35 +73,6 @@ class TensorBaseObj : public Object {
|
||||||
|
|
||||||
// VType getScalar() { return data == nullptr ? 0 : data[0]; }
|
// VType getScalar() { return data == nullptr ? 0 : data[0]; }
|
||||||
|
|
||||||
// VType getData(const Dim &ds) {
|
|
||||||
// assert(data != nullptr);
|
|
||||||
// auto offset = getOffset(ds);
|
|
||||||
// return offset == (size_t)-1 ? 0 : data[getOffset(ds)];
|
|
||||||
// }
|
|
||||||
|
|
||||||
// VType getData(size_t pos) {
|
|
||||||
// assert(data != nullptr);
|
|
||||||
// assert(pos < size());
|
|
||||||
// return data[pos];
|
|
||||||
// }
|
|
||||||
|
|
||||||
// VType *getDataPtr() const { return data; }
|
|
||||||
|
|
||||||
// size_t getOffset(const Dim &ds) {
|
|
||||||
// auto nDim = ds.size();
|
|
||||||
// assert(dims.size() == nDim);
|
|
||||||
// if (ds.empty())
|
|
||||||
// return 0;
|
|
||||||
// for (size_t i = 0; i < nDim; ++i)
|
|
||||||
// if (ds[i] < 0 || ds[i] >= dims[i])
|
|
||||||
// return (size_t)-1;
|
|
||||||
// size_t idx = ds[0];
|
|
||||||
// size_t dm = 0;
|
|
||||||
// while (++dm < nDim)
|
|
||||||
// idx = idx * dims[dm] + ds[dm];
|
|
||||||
// return idx;
|
|
||||||
// }
|
|
||||||
|
|
||||||
// VType getBroadcastData(const Dim &ds) {
|
// VType getBroadcastData(const Dim &ds) {
|
||||||
// assert(data != nullptr);
|
// assert(data != nullptr);
|
||||||
// auto offset = getBroadcastOffset(ds);
|
// auto offset = getBroadcastOffset(ds);
|
||||||
|
@ -155,96 +97,6 @@ class TensorBaseObj : public Object {
|
||||||
// idx = idx * dims[i] + ds[nBroadcastDim + i];
|
// idx = idx * dims[i] + ds[nBroadcastDim + i];
|
||||||
// return idx;
|
// return idx;
|
||||||
// }
|
// }
|
||||||
|
|
||||||
// void itInit() { it = Dim(dims.size(), 0); }
|
|
||||||
|
|
||||||
// void itReset() {
|
|
||||||
// itInit();
|
|
||||||
// for (size_t i = 0, iEnd = it.size(); i < iEnd; ++i)
|
|
||||||
// it[i] = 0;
|
|
||||||
// }
|
|
||||||
|
|
||||||
// bool itValid() {
|
|
||||||
// if (it.size() != dims.size())
|
|
||||||
// return false;
|
|
||||||
// for (size_t i = 0, iEnd = it.size(); i < iEnd; ++i)
|
|
||||||
// if (it[i] >= dims[i])
|
|
||||||
// return false;
|
|
||||||
// return true;
|
|
||||||
// }
|
|
||||||
|
|
||||||
// const Dim &itGet() { return it; }
|
|
||||||
|
|
||||||
// void itNext() {
|
|
||||||
// auto p = it.size() - 1;
|
|
||||||
// it[p] += 1;
|
|
||||||
// while (p >= 1) {
|
|
||||||
// if (it[p] == dims[p]) {
|
|
||||||
// it[p] = 0;
|
|
||||||
// it[--p] += 1;
|
|
||||||
// } else
|
|
||||||
// break;
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// size_t size() const {
|
|
||||||
// size_t sz = 1;
|
|
||||||
// auto dm = dims.size();
|
|
||||||
// while (dm > 0)
|
|
||||||
// sz *= dims[--dm];
|
|
||||||
// return sz;
|
|
||||||
// }
|
|
||||||
|
|
||||||
// TensorType getType() const { return type; }
|
|
||||||
// void setType(TensorType ty) { type = ty; }
|
|
||||||
|
|
||||||
// static inline void initFastrand() {
|
|
||||||
// assert(omp_get_max_threads() <= 256);
|
|
||||||
// // srand(0); // constant seed for test
|
|
||||||
// // align random_seed to avoid false sharing
|
|
||||||
// for (int i = 0; i < 256 * 16; ++i) {
|
|
||||||
// // random_seed[i] = rand();
|
|
||||||
// // constant random seed for test
|
|
||||||
// random_seed[i] = i;
|
|
||||||
// }
|
|
||||||
// random_inited = true;
|
|
||||||
// }
|
|
||||||
|
|
||||||
// static inline int fastrand(int &g_seed) {
|
|
||||||
// g_seed = (214013 * g_seed + 2531011);
|
|
||||||
// return (g_seed >> 16) & 0x7FFF;
|
|
||||||
// }
|
|
||||||
|
|
||||||
// std::vector<std::vector<int>> const *getSplittingPoints() const {
|
|
||||||
// assert(!splittingPoints.empty());
|
|
||||||
// return &splittingPoints;
|
|
||||||
// }
|
|
||||||
|
|
||||||
// bool setSplittingPoints(std::vector<std::vector<int>> value) {
|
|
||||||
// assert(!value.empty());
|
|
||||||
// splittingPoints = value;
|
|
||||||
// return true;
|
|
||||||
// }
|
|
||||||
|
|
||||||
// void printSplittingPoints() {
|
|
||||||
// if (splittingPoints.empty())
|
|
||||||
// printf("Empty SplittingPoints");
|
|
||||||
// else {
|
|
||||||
// printf("[");
|
|
||||||
// for (auto &vs : splittingPoints) {
|
|
||||||
// printf("[");
|
|
||||||
// for (auto v : vs)
|
|
||||||
// printf("%2d,", v);
|
|
||||||
// printf("],");
|
|
||||||
// }
|
|
||||||
// printf("]");
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// void initSplittingPoints() {
|
|
||||||
// splittingPoints.resize(getDims().size()); }
|
|
||||||
|
|
||||||
// void printShape();
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace infini
|
} // namespace infini
|
||||||
|
|
|
@ -2,7 +2,23 @@
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
|
|
||||||
void GraphObj::updateConnection() { IT_TODO_HALT(); }
|
void GraphObj::addOperatorAndConnect(const Operator &op) {
|
||||||
|
ops.push_back(op);
|
||||||
|
for (auto &input : op->getInputs()) {
|
||||||
|
input->addInputOf(op);
|
||||||
|
if (auto pred = input->getOutputOf()) {
|
||||||
|
pred->addSuccessors(op);
|
||||||
|
op->addPredecessors(pred);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (auto &output : op->getOutputs()) {
|
||||||
|
output->setOutputOf(op);
|
||||||
|
for (auto &succ : output->getInputOf()) {
|
||||||
|
succ->addPredecessors(op);
|
||||||
|
op->addSuccessors(succ);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
string GraphObj::toString() const {
|
string GraphObj::toString() const {
|
||||||
std::ostringstream oss;
|
std::ostringstream oss;
|
||||||
|
@ -11,8 +27,17 @@ string GraphObj::toString() const {
|
||||||
oss << tensor << "\n";
|
oss << tensor << "\n";
|
||||||
|
|
||||||
oss << "Graph operators:\n";
|
oss << "Graph operators:\n";
|
||||||
for (const auto &op : ops)
|
for (const auto &op : ops) {
|
||||||
oss << op << "\n";
|
vector<GuidBaseType> preds, succs;
|
||||||
|
for (auto &o : op->getPredecessors())
|
||||||
|
preds.emplace_back(o->getGuid());
|
||||||
|
for (auto &o : op->getSuccessors())
|
||||||
|
succs.emplace_back(o->getGuid());
|
||||||
|
oss << "OP " << op->getGuid();
|
||||||
|
oss << ", pred " << vecToString(preds);
|
||||||
|
oss << ", succ " << vecToString(succs);
|
||||||
|
oss << ", " << op << "\n";
|
||||||
|
}
|
||||||
return oss.str();
|
return oss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -4,6 +4,12 @@
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
|
|
||||||
|
OperatorObj::OperatorObj(OpType opType, TensorVec inputs, TensorVec outputs)
|
||||||
|
: type(opType), inputs(inputs), outputs(outputs) {
|
||||||
|
for (auto &t : inputs)
|
||||||
|
IT_ASSERT(t != nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
bool OperatorObj::isLinearOp() const {
|
bool OperatorObj::isLinearOp() const {
|
||||||
return enum_to_underlying(type) >= 100 && enum_to_underlying(type) < 200;
|
return enum_to_underlying(type) >= 100 && enum_to_underlying(type) < 200;
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
#include "core/tensor.h"
|
#include "core/tensor.h"
|
||||||
#include "core/blob.h"
|
#include "core/blob.h"
|
||||||
|
#include "core/operator.h"
|
||||||
#include "core/runtime.h"
|
#include "core/runtime.h"
|
||||||
#include "utils/dataloader.h"
|
#include "utils/dataloader.h"
|
||||||
|
|
||||||
|
@ -13,7 +14,17 @@ VType TensorObj::getData(const Shape &pos) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
string TensorObj::toString() const {
|
string TensorObj::toString() const {
|
||||||
return "Tensor " + std::to_string(guid) + " shape " + vecToString(shape);
|
string ret = "Tensor " + std::to_string(guid) + ", shape " +
|
||||||
|
vecToString(shape) + ", dtype " + dtype.toString();
|
||||||
|
vector<GuidBaseType> inputOfGuid;
|
||||||
|
for (const auto &op : inputOf)
|
||||||
|
inputOfGuid.emplace_back(op.lock()->getGuid());
|
||||||
|
if (auto o = outputOf.lock())
|
||||||
|
ret += ", outputOf " + std::to_string(o->getGuid());
|
||||||
|
else
|
||||||
|
ret += ", outputOf None";
|
||||||
|
ret += ", inputOf " + vecToString(inputOfGuid);
|
||||||
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t TensorObj::getOffset(const Shape &pos) const {
|
size_t TensorObj::getOffset(const Shape &pos) const {
|
||||||
|
|
|
@ -15,9 +15,20 @@ TEST(Graph, build_and_run) {
|
||||||
g->dataMalloc();
|
g->dataMalloc();
|
||||||
i0->copyData(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
i0->copyData(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||||
w0->copyData(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
w0->copyData(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||||
g->addOpWithOutputs<MatmulObj>(i0, w0, o0);
|
auto matmul = g->addOpWithOutputs<MatmulObj>(i0, w0, o0);
|
||||||
|
g->print();
|
||||||
|
// check inputOf and outputsOf for tensor
|
||||||
|
EXPECT_EQ(i0->getInputOf().size(), 1);
|
||||||
|
EXPECT_EQ(w0->getInputOf().size(), 1);
|
||||||
|
EXPECT_EQ(o0->getInputOf().size(), 0);
|
||||||
|
EXPECT_EQ(i0->getOutputOf(), nullptr);
|
||||||
|
EXPECT_EQ(w0->getOutputOf(), nullptr);
|
||||||
|
EXPECT_NE(o0->getOutputOf(), nullptr);
|
||||||
|
EXPECT_EQ(matmul->getPredecessors().size(), 0);
|
||||||
|
EXPECT_EQ(matmul->getSuccessors().size(), 0);
|
||||||
|
|
||||||
runtime->run(g);
|
runtime->run(g);
|
||||||
// check answer
|
// check execution results
|
||||||
auto ans = make_ref<TensorObj>(Shape{1, 2, 4}, DataType::UInt32, runtime);
|
auto ans = make_ref<TensorObj>(Shape{1, 2, 4}, DataType::UInt32, runtime);
|
||||||
ans->dataMalloc();
|
ans->dataMalloc();
|
||||||
ans->copyData(vector<uint32_t>{38, 44, 50, 56, 83, 98, 113, 128});
|
ans->copyData(vector<uint32_t>{38, 44, 50, 56, 83, 98, 113, 128});
|
||||||
|
|
|
@ -0,0 +1,84 @@
|
||||||
|
#include "core/graph.h"
|
||||||
|
#include "core/kernel.h"
|
||||||
|
#include "core/runtime.h"
|
||||||
|
#include "cuda/cuda_runtime.h"
|
||||||
|
#include "cuda/cuda_utility.h"
|
||||||
|
#include "operators/concat.h"
|
||||||
|
#include "operators/conv.h"
|
||||||
|
#include "operators/pooling.h"
|
||||||
|
#include "operators/unary.h"
|
||||||
|
|
||||||
|
#include "test.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
TEST(CUDA_Inception_v3_block, run) {
|
||||||
|
const int bs = 1, initialChannels = 192, h = 32;
|
||||||
|
|
||||||
|
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||||
|
auto g = make_ref<GraphObj>(cudaRuntime);
|
||||||
|
auto blockInput = g->addTensor({bs, initialChannels, h, h});
|
||||||
|
vector<vector<tuple<bool, int, int>>> configs =
|
||||||
|
// <isConv, f, r/s>
|
||||||
|
{
|
||||||
|
{{true, 64, 1}}, // a chain with one Conv
|
||||||
|
{{true, 48, 1}, {true, 64, 5}},
|
||||||
|
{{true, 64, 1}, {true, 96, 3}, {true, 96, 3}},
|
||||||
|
{{false, 192, 3}, {true, 32, 3}},
|
||||||
|
};
|
||||||
|
TensorVec outputs;
|
||||||
|
vector<OpVec> ops;
|
||||||
|
auto maxpool =
|
||||||
|
g->addOp<MaxPoolObj>(blockInput, nullptr, 3, 3, 1, 1, 1, 1, 1, 1);
|
||||||
|
auto chainInput = maxpool->getOutput();
|
||||||
|
for (auto &pathConfig : configs) {
|
||||||
|
int inputChannels = initialChannels;
|
||||||
|
auto input = chainInput;
|
||||||
|
ops.emplace_back();
|
||||||
|
for (auto &[isConv, f, r] : pathConfig) { // OpConfig
|
||||||
|
if (isConv) {
|
||||||
|
{ // Add Conv
|
||||||
|
auto w = g->addTensor({f, inputChannels, r, r});
|
||||||
|
auto conv =
|
||||||
|
g->addOp<ConvObj>(input, w, nullptr, r / 2, r / 2);
|
||||||
|
input = conv->getOutput();
|
||||||
|
ops.back().emplace_back(conv);
|
||||||
|
}
|
||||||
|
{ // Add Relu
|
||||||
|
auto relu = g->addOp<ReluObj>(input, nullptr);
|
||||||
|
input = relu->getOutput();
|
||||||
|
ops.back().emplace_back(relu);
|
||||||
|
}
|
||||||
|
inputChannels = f;
|
||||||
|
} else { // Add AveragePool
|
||||||
|
auto pool = g->addOp<AvgPoolObj>(input, nullptr, r, r, 1, 1,
|
||||||
|
r / 2, r / 2, 1, 1);
|
||||||
|
input = pool->getOutput();
|
||||||
|
ops.back().emplace_back(pool);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
outputs.emplace_back(input);
|
||||||
|
}
|
||||||
|
auto concat = g->addOp<ConcatObj>(outputs, nullptr, 1);
|
||||||
|
g->print();
|
||||||
|
|
||||||
|
// check connection
|
||||||
|
EXPECT_EQ(maxpool->getSuccessors().size(), 4u);
|
||||||
|
EXPECT_EQ(chainInput->getInputOf().size(), 4u);
|
||||||
|
for (const auto &chainOps : ops) {
|
||||||
|
for (size_t i = 1; i < chainOps.size(); i++) {
|
||||||
|
auto prev = chainOps[i - 1];
|
||||||
|
auto cur = chainOps[i];
|
||||||
|
EXPECT_EQ(prev->getSuccessors().size(), 1u);
|
||||||
|
EXPECT_EQ(cur->getPredecessors().size(), 1u);
|
||||||
|
EXPECT_EQ(prev->getSuccessors()[0], cur);
|
||||||
|
EXPECT_EQ(prev, cur->getPredecessors()[0]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
EXPECT_EQ(concat->getPredecessors().size(), 4u);
|
||||||
|
|
||||||
|
// TODO: check outputs
|
||||||
|
g->dataMalloc();
|
||||||
|
cudaRuntime->run(g);
|
||||||
|
}
|
||||||
|
}; // namespace infini
|
Loading…
Reference in New Issue