forked from jiuyuan/InfiniTensor
commit 65a3abf5dc
Makefile (9 changes)

@@ -1,10 +1,17 @@
 .PHONY : build clean install-python test-cpp test-onnx

 TYPE ?= release
+CUDA ?= off
+
+CMAKE_OPT = -DCMAKE_BUILD_TYPE=$(TYPE)
+ifeq ($(CUDA), ON)
+	CMAKE_OPT += -DUSE_CUDA=ON
+endif

 build:
 	mkdir -p build/$(TYPE)
-	cd build/$(TYPE) && cmake -DCMAKE_BUILD_TYPE=$(TYPE) ../.. && make -j8
+	cd build/$(TYPE) && cmake $(CMAKE_OPT) ../.. && make -j8

 clean:
 	rm -rf build
README.md (15 changes)

@@ -5,16 +5,27 @@
 ``` bash
 # Enter the root of InfiniTensor
 source test/script/env_lotus.sh
-mkdir build && cd build
-cmake -DUSE_CUDA=ON .. && make -j 12
+make CUDA=ON
 ```

+### Make Commands
+
+- `make`/`make build`: Builds the project;
+- `make install-python`: Builds the project, then installs the Python frontend;
+- `make test-cpp`: Builds the project, then runs the C++ unit tests;
+- `make test-onnx`: Runs the Python unit tests;
+
+---
+
+> Set env `CUDA=ON` to enable CUDA.
+
 ### CMake Options

 There are several configurable CMake options; see the [CMakeLists.txt file](/CMakeLists.txt#L5).

 - If `USE_BACKTRACE` is `ON`, `libdw-dev` has to be installed. See the README of [backward-cpp](https://github.com/bombela/backward-cpp) for details.
 - If `USE_PROTOBUF` is `ON`, `protobuf` has to be installed. See the README of [protobuf](https://github.com/protocolbuffers/protobuf) for details.
 - If `USE_CUDA` is `ON`, `cuda` has to be installed.

 ## Contributor Guide
@@ -63,7 +63,7 @@ class GraphObj : public Object {
     inline TensorVec getInputs() const {
         TensorVec ret;
         for (const auto &t : tensors)
-            if (!t->getOutputOf())
+            if (!t->getSource())
                 ret.emplace_back(t);
         return ret;
     }
@@ -74,7 +74,7 @@ class GraphObj : public Object {
     inline TensorVec getOutputs() const {
         TensorVec ret;
         for (const auto &t : tensors)
-            if (t->getInputOf().empty())
+            if (t->getTargets().empty())
                 ret.emplace_back(t);
         return ret;
     }
@@ -2,6 +2,8 @@

 #include "core/graph.h"
 #include "core/runtime.h"
+#include <cstdint>
+#include <iostream>

 namespace infini {
@@ -1,6 +1,11 @@
 #pragma once
 #include "core/tensor_base.h"
 #include <cmath>
+#include <cstring>
+
+#if USE_CUDA
+#include "cuda/cuda_runtime.h"
+#endif

 namespace infini {
@@ -10,38 +15,56 @@ using Shape = vector<ShapeElem>;
 class TensorObj : public TensorBaseObj {
   private:
     Shape shape;
-    Fuid fuid; // Cloned tensors share the same id. Tensors constructed from
-               // scratch have a new id.
+    size_t _size; // Cache of Π(shape).
+    Fuid fuid;    // Cloned tensors share the same id. Tensors constructed from
+                  // scratch have a new id.
+
+    void copyin(const void *ptr, size_t size) {
+        runtime->copyBlobFromCPU(getRawDataPtr<void *>(), ptr, size);
+    }
+    void copyout(void *ptr, size_t size) const {
+        runtime->copyBlobToCPU(ptr, getRawDataPtr<void *>(), size);
+    }

   public:
-    TensorObj(const Shape &shape, DataType dtype, Runtime runtime);
+    TensorObj(Shape shape, DataType dtype, Runtime runtime);
     virtual ~TensorObj() {}
     string toString() const override;

-    size_t size() const;
-    size_t getBytes() const;
+    size_t size() const { return _size; }
+    size_t getBytes() const { return _size * dtype.getSize(); }

     Shape getDims() const { return shape; }
     vector<size_t> getStride() const;
-    size_t getOffset(const Shape &ds) const;
-    using TensorBaseObj::getData;
-    VType getData(const Shape &pos) const;
+    size_t getOffset(const vector<int> &ds) const;
     void dataMalloc();
+    UidBaseType getFuid() const { return fuid; }

     void load(std::string file_path);
     void save(std::string file_path);

-    template <typename T> void copyData(const T *dptr) {
+    // Copy elements from `data`.
+    template <typename T> void copyin(const vector<T> &data) {
         IT_ASSERT(DataType::get<T>() == dtype);
-        IT_ASSERT(data != nullptr);
-        runtime->copyBlobFromCPU(getRawDataPtr<void *>(), dptr, getBytes());
+        IT_ASSERT(data.size() >= _size);
+        copyin(data.data(), getBytes());
     }
-
-    template <typename T> void copyData(vector<T> dataVector) {
+    // Copy all the elements to a vector.
+    template <typename T> auto copyout() const {
         IT_ASSERT(DataType::get<T>() == dtype);
-        IT_ASSERT(dataVector.size() >= size());
-        copyData(dataVector.data());
+        std::vector<T> ans(_size);
+        copyout(ans.data(), getBytes());
+        return ans;
     }
+    // Copy the element at `pos`.
+    template <typename T> auto copyOne(const vector<int> &pos) const {
+        IT_ASSERT(DataType::get<T>() == dtype);
+        auto offset = getOffset(pos);
+        auto bytes = dtype.getSize();
+        T ans;
+        runtime->copyBlobToCPU(
+            &ans, getRawDataPtr<uint8_t *>() + offset * bytes, bytes);
+        return ans;
+    }

     void copyData(const TensorObj *src);
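The `copyin`/`copyout`/`copyOne` templates above replace the old `copyData` overloads with a typed, size-checked transfer API. A minimal usage sketch (assuming the InfiniTensor headers; `CpuRuntimeObj::getInstance` is the singleton exported in the FFI changes further down):

```cpp
#include "core/runtime.h"
#include "core/tensor.h"
#include <vector>

using namespace infini;

void copyApiDemo() {
    auto runtime = CpuRuntimeObj::getInstance();
    auto t = make_ref<TensorObj>(Shape{2, 3}, DataType::Float32, runtime);
    t->dataMalloc();

    // copyin asserts the element type and that the vector covers _size.
    t->copyin(std::vector<float>{1, 2, 3, 4, 5, 6});

    // copyout<T>() allocates and returns all elements.
    std::vector<float> all = t->copyout<float>();

    // copyOne reads a single element at a multi-dimensional position.
    float last = t->copyOne<float>({1, 2}); // row 1, column 2 -> 6.0f
}
```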
@@ -51,17 +74,16 @@ class TensorObj : public TensorBaseObj {
     Tensor clone() const {
         auto obj = make_ref<TensorObj>(*this);
         obj->freeData();
-        obj->inputOf.clear();
-        obj->outputOf.reset();
+        obj->targets.clear();
+        obj->source.reset();
         return obj;
     }
+    // TODO: clarify whether clone copies data
     Tensor clone(Runtime runtime) const {
         auto obj = make_ref<TensorObj>(*this);
         obj->runtime = runtime;
         obj->freeData();
-        obj->inputOf.clear();
-        obj->outputOf.reset();
+        obj->targets.clear();
+        obj->source.reset();
         if (hasData()) {
             obj->dataMalloc();
             obj->copyData(this);
@@ -19,8 +19,8 @@ class TensorBaseObj : public Object {
     int dim;

     DataType dtype;
-    vector<WRef<OperatorObj>> inputOf;
-    WRef<OperatorObj> outputOf;
+    vector<WRef<OperatorObj>> targets;
+    WRef<OperatorObj> source;
     Blob data;
     Runtime runtime;
@@ -41,16 +41,18 @@ class TensorBaseObj : public Object {
         IT_ASSERT(data != nullptr);
         return data->getPtr<T>();
     }
-    VType getData(size_t offset) const;

     DataType getDType() const { return dtype; }
     Runtime getRuntime() const { return runtime; }

-    void addInputOf(const Operator &op) { inputOf.emplace_back(op); }
-    void setOutputOf(const Operator &op) { outputOf = op; }
-    OpVec getInputOf() { return wrefs_to_refs(inputOf); }
-    Operator getOutputOf() { return outputOf.lock(); }
-    // std::pair<Operator *, int> getOutputOfWithIndex();
+    void addTarget(const Operator &op) { targets.emplace_back(op); }
+    void setSource(const Operator &op) { source = op; }
+
+    bool hasTarget() const { return !targets.empty(); }
+
+    OpVec getTargets() const { return wrefs_to_refs(targets); }
+    Operator getSource() const { return source.lock(); }
+    // std::pair<Operator *, int> getSourceWithIndex();

     // bool setScalar(VType val) {
     //     if (data == nullptr || !dims.empty())
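After the rename, a tensor's producing operator is its `source` and its consumers are its `targets`, so a graph input is a tensor with no source and a graph output is one with no targets (exactly the predicates `GraphObj::getInputs`/`getOutputs` use above). A small sketch, assuming the InfiniTensor headers:

```cpp
#include "core/graph.h"
#include <utility>

using namespace infini;

// Count graph-level inputs and outputs by tensor connectivity; mirrors
// the logic of GraphObj::getInputs/getOutputs shown earlier.
std::pair<int, int> countBoundary(const Graph &g) {
    int inputs = 0, outputs = 0;
    for (const auto &t : g->getTensors()) {
        if (!t->getSource()) // no producer: a graph input
            ++inputs;
        if (!t->hasTarget()) // no consumers: a graph output
            ++outputs;
    }
    return {inputs, outputs};
}
```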
@@ -41,7 +41,9 @@ class BatchNormObj : public OperatorObj {
     // output size will be 3 when training
     int numInputs() const override { return 5; }
     int numOutputs() const override { return outputs.size(); }
+    float getMomentum() const { return momentum; }
+    float getEps() const { return eps; }
+    bool getTraining() const { return training; }

   private:
     vector<int> getWorkloadVector() const override;
File diff suppressed because it is too large.
@@ -8,13 +8,13 @@ from onnx.helper import (
     make_tensor_value_info,
 )
 from onnx.checker import check_model
-from pyinfinitensor.onnx import from_onnx, parse_onnx, backend, runtime, to_onnx
+from pyinfinitensor.onnx import from_onnx, backend, cpu_runtime


 def make_and_import_model(graph: onnx.GraphProto):
     model = make_model(graph)
     check_model(model)
-    from_onnx(model)
+    from_onnx(model, cpu_runtime)


 class TestStringMethods(unittest.TestCase):
@@ -28,7 +28,7 @@ class TestStringMethods(unittest.TestCase):
                 file=model_file, size=os.path.getsize(model_file) / 1024 / 1024
             )
         )
-        parse_onnx(onnx.load(model_file))
+        from_onnx(onnx.load(model_file), cpu_runtime)

    def test_tensor(self):
        x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
@@ -66,10 +66,10 @@ class TestStringMethods(unittest.TestCase):

    def test_batch_norm(self):
        x = make_tensor_value_info("x", TensorProto.UINT32, [1, 3, 2, 2])
-        scale = make_tensor_value_info("scale", TensorProto.FLOAT, [1, 3, 1, 1])
-        b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 1, 1])
-        mean = make_tensor_value_info("mean", TensorProto.FLOAT, [1, 3, 1, 1])
-        var = make_tensor_value_info("var", TensorProto.FLOAT, [1, 3, 1, 1])
+        scale = make_tensor_value_info("scale", TensorProto.FLOAT, [3])
+        b = make_tensor_value_info("b", TensorProto.FLOAT, [3])
+        mean = make_tensor_value_info("mean", TensorProto.FLOAT, [3])
+        var = make_tensor_value_info("var", TensorProto.FLOAT, [3])
        y = make_tensor_value_info("y", TensorProto.UINT32, [1, 3, 2, 2])
        batch_norm = make_node(
            "BatchNormalization",
@@ -289,11 +289,10 @@ class TestStringMethods(unittest.TestCase):
        graph = make_graph([matmul, add], "lr", [x, a, b], [y])
        model = make_model(graph)
        check_model(model)
-        from_onnx(model)
-        parse_onnx(model)
+        from_onnx(model, cpu_runtime)

    def test_frontend(self):
-        handler = backend.GraphHandler(runtime)
+        handler = backend.GraphHandler(cpu_runtime)
        a = handler.tensor([1, 2, 3], 12)
        b = handler.tensor([1, 2, 3], 12)
        c = handler.tensor([1, 2, 3], 12)
@@ -306,8 +305,6 @@ class TestStringMethods(unittest.TestCase):
        y = handler.tensor([3, 2, 1], 12)
        handler.reshape(x, y, [3, 2, 1])

-        to_onnx(handler, "test_frontend")
-

 if __name__ == "__main__":
    unittest.main()
@@ -33,15 +33,15 @@ void GraphObj::addOperatorAndConnect(const Operator &op) {
     sorted = false;
     ops.push_back(op);
     for (auto &input : op->getInputs()) {
-        input->addInputOf(op);
-        if (auto pred = input->getOutputOf()) {
+        input->addTarget(op);
+        if (auto pred = input->getSource()) {
             pred->addSuccessors(op);
             op->addPredecessors(pred);
         }
     }
     for (auto &output : op->getOutputs()) {
-        output->setOutputOf(op);
-        for (auto &succ : output->getInputOf()) {
+        output->setSource(op);
+        for (auto &succ : output->getTargets()) {
             succ->addPredecessors(op);
             op->addSuccessors(succ);
         }
@@ -87,7 +87,7 @@ bool GraphObj::topo_sort() {
         // this node is a head node.
         const auto is_head = std::all_of(
             this_inputs.begin(), this_inputs.end(), [&](const auto &input) {
-                auto src = input->getOutputOf();
+                auto src = input->getSource();
                 return src // If the source node is in the waiting list,
                            // means that this node is not the head node.
                            ? waiting.find(src) == waiting.end()
@@ -3,32 +3,33 @@
 #include "core/operator.h"
 #include "core/runtime.h"
 #include "utils/dataloader.h"
+#include <numeric>

 namespace infini {

-TensorObj::TensorObj(const Shape &shape, DataType dtype, Runtime runtime)
-    : TensorBaseObj(shape.size(), dtype, runtime), shape(shape) {}
-
-VType TensorObj::getData(const Shape &pos) const {
-    return getData(getOffset(pos));
-}
+TensorObj::TensorObj(Shape shape_, DataType dtype, Runtime runtime)
+    : TensorBaseObj(shape_.size(), dtype, runtime), shape(std::move(shape_)),
+      _size(shape.empty()
+                ? 0
+                : std::accumulate(shape.begin(), shape.end(), 1,
+                                  [](auto acc, auto x) { return acc * x; })) {}

 string TensorObj::toString() const {
     string ret = "Tensor " + std::to_string(guid) + ", Fuid " +
                  std::to_string(fuid) + ", shape " + vecToString(shape) +
                  ", dtype " + dtype.toString();
-    vector<UidBaseType> inputOfGuid;
-    for (const auto &op : inputOf)
-        inputOfGuid.emplace_back(op.lock()->getGuid());
-    if (auto o = outputOf.lock())
-        ret += ", outputOf " + std::to_string(o->getGuid());
+    vector<UidBaseType> targetGuids;
+    for (const auto &op : targets)
+        targetGuids.emplace_back(op.lock()->getGuid());
+    if (auto o = source.lock())
+        ret += ", source " + std::to_string(o->getGuid());
     else
-        ret += ", outputOf None";
-    ret += ", inputOf " + vecToString(inputOfGuid);
+        ret += ", source None";
+    ret += ", targets " + vecToString(targetGuids);
     return ret;
 }

-size_t TensorObj::getOffset(const Shape &pos) const {
+size_t TensorObj::getOffset(const vector<int> &pos) const {
     auto nDim = pos.size();
     IT_ASSERT(shape.size() == nDim);
     if (pos.empty())
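`_size` is now computed once, in the constructor, as the product of the shape's extents. One thing to note: `std::accumulate` deduces its accumulator type from the init value, and the literal `1` is an `int`, so the product is formed in `int` before widening to `size_t`. A standalone sketch that sidesteps this by accumulating in `size_t` (names hypothetical):

```cpp
#include <cstddef>
#include <numeric>
#include <vector>

// Π(shape), with 0 for an empty shape as the constructor above chooses.
// size_t{1} as the init value keeps the accumulation in size_t.
size_t elementCount(const std::vector<int> &shape) {
    return shape.empty()
               ? 0
               : std::accumulate(shape.begin(), shape.end(), size_t{1},
                                 [](size_t acc, int d) { return acc * d; });
}

// e.g. elementCount({2, 3, 4}) == 24; at 4 bytes per float32 element,
// getBytes() would report 96.
```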
@@ -53,15 +54,6 @@ vector<size_t> TensorObj::getStride() const {
     return ret;
 }

-size_t TensorObj::size() const {
-    size_t ret = 1;
-    for (const auto &d : shape)
-        ret *= d;
-    return ret;
-}
-
-size_t TensorObj::getBytes() const { return size() * dtype.getSize(); }
-
 void TensorObj::printData() const {
     IT_ASSERT(data != nullptr);
     if (!runtime->isCpu())
@@ -148,15 +140,8 @@ bool TensorObj::equalData(const Tensor &rhs) const {
 }

 void TensorObj::dataMalloc() {
-    if (data != nullptr)
-        return;
-    // IT_ASSERT(data == nullptr);
-    size_t bytesPerElement;
-    if (getDType() == DataType::Float32)
-        bytesPerElement = sizeof(float);
-    else if (getDType() == DataType::UInt32)
-        bytesPerElement = sizeof(uint32_t);
-    data = runtime->allocBlob(size() * bytesPerElement);
+    if (data == nullptr)
+        data = runtime->allocBlob(getBytes());
 }

 void TensorObj::copyData(const TensorObj *src) {
@@ -6,9 +6,4 @@ namespace infini {
 TensorBaseObj::TensorBaseObj(int dim, DataType dtype, Runtime runtime)
     : dim(dim), dtype(dtype), runtime(runtime) {}

-VType TensorBaseObj::getData(size_t offset) const {
-    // TODO: check cuda array
-    return (data->getPtr<VType *>())[offset];
-}
-
 }; // namespace infini
@@ -1,11 +1,15 @@
 #include "core/graph_handler.h"
 #include "operators/batch_norm.h"
 #include "operators/concat.h"
 #include "operators/conv.h"
 #include "operators/gather.h"
 #include "operators/pooling.h"
 #include "operators/reduce_mean.h"
 #include "operators/reshape.h"
 #include <pybind11/stl.h>

+#ifdef USE_CUDA
+#include "cuda/cuda_runtime.h"
+#include "cuda/operator_timer.h"
+#endif
@@ -94,6 +98,34 @@ static int tensor_dtype(Tensor t) {
     IT_ASSERT(false, "Unsupported data type");
 }

+#ifdef USE_CUDA
+static Ref<CudaRuntimeObj> cuda_runtime() { return make_ref<CudaRuntimeObj>(); }
+#endif
+
+static std::tuple<int, int, int, int, int, int> conv_attrs_of(Operator op) {
+    IT_ASSERT(op->getOpType() == OpType::Conv);
+    auto conv = dynamic_cast<const ConvObj *>(op.get());
+    return std::make_tuple(conv->getPh(), conv->getPw(), conv->getDh(),
+                           conv->getDw(), conv->getSh(), conv->getSw());
+}
+
+static std::tuple<float, float, bool> batch_norm_attrs_of(Operator op) {
+    IT_ASSERT(op->getOpType() == OpType::BatchNorm);
+    auto batchnorm = dynamic_cast<const BatchNormObj *>(op.get());
+    return std::make_tuple(batchnorm->getMomentum(), batchnorm->getEps(),
+                           batchnorm->getTraining());
+}
+
+static std::tuple<int, int, int, int, int, int, int, int>
+pool_attrs_of(Operator op) {
+    IT_ASSERT(op->getOpType() == OpType::MaxPool ||
+              op->getOpType() == OpType::AvgPool);
+    auto pool = dynamic_cast<const PoolingObj *>(op.get());
+    return std::make_tuple(pool->getKh(), pool->getKw(), pool->getDh(),
+                           pool->getDw(), pool->getPh(), pool->getPw(),
+                           pool->getSh(), pool->getSw());
+}
+
 static int concat_axis_of(Operator op) {
     IT_ASSERT(op->getOpType() == OpType::Concat);
     return dynamic_cast<const ConcatObj *>(op.get())->getDim();
@@ -118,6 +150,12 @@ static Shape reshape_shape_of(Operator op) {
 void export_functions(py::module &m) {
 #define FUNCTION(NAME) def(#NAME, &NAME)
     m.def("cpu_runtime", &CpuRuntimeObj::getInstance)
+#ifdef USE_CUDA
+        .FUNCTION(cuda_runtime)
+#endif
+        .FUNCTION(conv_attrs_of)
+        .FUNCTION(batch_norm_attrs_of)
+        .FUNCTION(pool_attrs_of)
         .FUNCTION(tensor_dtype)
         .FUNCTION(reshape_shape_of)
         .FUNCTION(concat_axis_of)
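`FUNCTION(NAME)` expands to `def(#NAME, &NAME)`: the preprocessor's stringizing operator `#` turns the identifier into the exported Python name, so the binding name can never drift from the C++ function name. A standalone illustration of the same trick:

```cpp
#include <iostream>

// #NAME stringizes the identifier, just as FUNCTION does above.
#define SHOW(NAME) (std::cout << #NAME << " = " << (NAME) << '\n')

int main() {
    int pool_kh = 3;
    SHOW(pool_kh); // prints: pool_kh = 3
    return 0;
}
```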
@@ -132,9 +170,21 @@ void init_graph_builder(py::module &m) {
     py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime");
     py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
         m, "CpuRuntime");
+#ifdef USE_CUDA
+    py::class_<CudaRuntimeObj, std::shared_ptr<CudaRuntimeObj>, RuntimeObj>(
+        m, "CudaRuntime");
+#endif
     py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "Tensor")
+        .def("fuid", &TensorObj::getFuid, policy::automatic)
         .def("shape", &TensorObj::getDims, policy::move)
-        .def("src", &TensorObj::getOutputOf, policy::move);
+        .def("copyin_float", &TensorObj::copyin<float>, policy::move)
+        .def("copyin_int32", &TensorObj::copyin<int32_t>, policy::move)
+        .def("copyin_int64", &TensorObj::copyin<int64_t>, policy::move)
+        .def("copyout_float", &TensorObj::copyout<float>, policy::move)
+        .def("copyout_int32", &TensorObj::copyout<int32_t>, policy::move)
+        .def("copyout_int64", &TensorObj::copyout<int64_t>, policy::move)
+        .def("has_target", &TensorObj::hasTarget, policy::automatic)
+        .def("src", &TensorObj::getSource, policy::move);
     py::class_<OperatorObj, std::shared_ptr<OperatorObj>>(m, "Operator")
         .def("op_type", &OperatorObj::getOpType, policy::automatic)
         .def("inputs", py::overload_cast<>(&OperatorObj::getInputs, py::const_),
@@ -165,7 +215,7 @@ void init_graph_builder(py::module &m) {
         .def("reshape", &Handler::reshape, policy::move)
         .def("concat", &Handler::concat, policy::move)
         .def("gather", &Handler::gather, policy::move)
-        .def("reduceMean", &Handler::reduceMean, policy::move)
+        .def("reduce_mean", &Handler::reduceMean, policy::move)
         .def("slice", &Handler::slice, policy::move)
         .def("pad", &Handler::pad, policy::move)
         .def("topo_sort", &Handler::topo_sort, policy::automatic)
@@ -64,7 +64,7 @@ class MemboundInterpreter : public Kernel {
         vector<uint32_t> valsUint(vals.size());
         for (size_t i = 0; i < vals.size(); ++i)
             valsUint[i] = (uint32_t)vals[i];
-        output->copyData(valsUint);
+        output->copyin(valsUint);
     }

     void compute(const Operator &op, const RuntimeObj *context) const override {
@@ -81,4 +81,4 @@ class MemboundInterpreter : public Kernel {
 REGISTER_KERNEL(Device::CPU, OpType::MemBound, DataType::UInt32,
                 MemboundInterpreter, "MemboundInterpreter_CPU");

-} // namespace infini
+} // namespace infini
@@ -2,6 +2,7 @@
 #include "core/kernel.h"
 #include "cuda/cuda_kernel_wihtout_config.h"
+#include "cuda/cuda_runtime.h"

 namespace infini {
 class BatchNormCudnn : public CudaKernelWithoutConfig {
     void compute(const Operator &_op,
@@ -28,9 +29,11 @@ class BatchNormCudnn : public CudaKernelWithoutConfig {
         for (size_t i = 0; i < dims.size(); ++i) {
             dimArray[i] = dims[i];
             strideArray[i] = op->getInputs(0)->getStride()[i];
-            dimPArray[i] = op->getInputs(1)->getDims()[i];
-            stridePArray[i] = op->getInputs(1)->getStride()[i];
+            dimPArray[i] = 1;
+            stridePArray[i] = 1;
         }
+        dimPArray[1] = op->getInputs(0)->getDims()[1];
+        stridePArray[1] = op->getInputs(0)->getStride()[1];
         // get inputs
         cudnnTensorDescriptor_t inDesc;
         checkCudnnError(cudnnCreateTensorDescriptor(&inDesc));
@@ -19,12 +19,9 @@ BatchNormObj::inferShape(const TensorVec &inputs) const {
     auto var = inputs[2];
     auto scale = inputs[3];
     auto bias = inputs[4];
-    if (input->getDims().size() < 2)
-        return {};
-    Shape dims(input->getDims().size(), 1);
-    dims[1] = input->getDims()[1]; //
-    if (mean->getDims() != dims || var->getDims() != dims ||
-        scale->getDims() != dims || bias->getDims() != dims)
+    auto c = std::vector<int>{input->getDims()[1]};
+    if (mean->getDims() != c || var->getDims() != c || scale->getDims() != c ||
+        bias->getDims() != c)
         return {};
     return {{input->getDims()}};
 }
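The new check matches the ONNX BatchNormalization convention: for an NCHW input, `mean`, `var`, `scale`, and `bias` are one-dimensional tensors of length C. A standalone sketch of the shape rule (it keeps a rank guard that the rewritten method drops):

```cpp
#include <vector>

// For input shape {N, C, ...}, every batch-norm parameter must have
// shape {C}; in inference mode the output shape equals the input shape.
bool batchNormParamShapeOk(const std::vector<int> &input,
                           const std::vector<int> &param) {
    if (input.size() < 2) // need at least {N, C}
        return false;
    return param == std::vector<int>{input[1]};
}
// e.g. input {1, 3, 2, 2} requires mean/var/scale/bias of shape {3}.
```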
@@ -5,10 +5,26 @@ namespace infini {
 MatmulObj::MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, bool transA,
                      bool transB, [[maybe_unused]] Tensor bias, ActType act)
     : OperatorObj(OpType::Matmul, {A, B}, {C}), transA(transA), transB(transB),
-      act(act), b(A->getDims()[0]),
-      m(transA ? A->getDims()[2] : A->getDims()[1]),
-      n(transB ? B->getDims()[1] : B->getDims()[2]),
-      k(transA ? A->getDims()[1] : A->getDims()[2]) {
+      act(act), b(1) {
+    auto shape_a = A->getDims();
+    auto shape_b = B->getDims();
+    IT_ASSERT(shape_a.size() == shape_b.size());
+    switch (shape_a.size()) {
+    case 0:
+    case 1:
+        IT_ASSERT(false);
+    case 2:
+        break;
+    default:
+        for (size_t i = 0; i < shape_a.size() - 2; ++i) {
+            IT_ASSERT(shape_a[i] == shape_b[i]);
+            b *= shape_a[i];
+        }
+        break;
+    }
+    m = *(transA ? shape_a.rbegin() : shape_a.rbegin() + 1);
+    n = *(transB ? shape_b.rbegin() + 1 : shape_b.rbegin());
+    k = *(transA ? shape_a.rbegin() + 1 : shape_a.rbegin());
     IT_ASSERT(checkValid(graph));
 }
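The rewritten constructor generalizes matmul from fixed 3-D inputs to N-D batched inputs: every leading axis must match between A and B and is folded into the batch size `b`, while `m`, `n`, `k` are read off the last two axes (swapped when `transA`/`transB` is set). A worked sketch of the non-transposed case (helper name hypothetical):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

struct MatmulDims { int b, m, n, k; };

// Mirrors the constructor above with transA = transB = false.
MatmulDims foldBatch(const std::vector<int> &a, const std::vector<int> &b) {
    assert(a.size() == b.size() && a.size() >= 2);
    int batch = 1;
    for (size_t i = 0; i + 2 < a.size(); ++i) {
        assert(a[i] == b[i]); // leading (batch) axes must agree
        batch *= a[i];
    }
    int m = a[a.size() - 2], k = a[a.size() - 1];
    int n = b[b.size() - 1];
    assert(b[b.size() - 2] == k); // inner dimensions must agree
    return {batch, m, n, k};
}

// e.g. A = {2, 5, 3, 4} and B = {2, 5, 4, 6} give b = 10, m = 3, n = 6,
// k = 4, and inferShape rewrites the last two axes to {2, 5, 3, 6}.
```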
@@ -22,19 +38,11 @@ string MatmulObj::toString() const {
 }

 optional<vector<Shape>> MatmulObj::inferShape(const TensorVec &inputs) const {
-    auto A = inputs[0], B = inputs[1];
-    // if (A->getType() == Tensor::Weight && B->getType() == Tensor::Weight)
-    //     return false;
-    if (!(A->getDims().size() == 3 && B->getDims().size() == 3))
-        return {};
-    if (!(A->getDims()[0] == B->getDims()[0]))
-        return {};
-    if (!((transA ? A->getDims()[1] : A->getDims()[2]) ==
-          (transB ? B->getDims()[2] : B->getDims()[1])))
-        return {};
-    int b(A->getDims()[0]), m(transA ? A->getDims()[2] : A->getDims()[1]),
-        n(transB ? B->getDims()[1] : B->getDims()[2]);
-    return {{{b, m, n}}};
+    auto shape_a = inputs[0]->getDims();
+    auto it = shape_a.rbegin();
+    *it++ = n;
+    *it++ = m;
+    return {{std::move(shape_a)}};
 }

 vector<int> MatmulObj::getWorkloadVector() const {
@@ -59,13 +59,13 @@ void loadTensorData(TensorObj *tensor, std::string file_path) {
         for (int i = 0; i < temp.data_float_size(); ++i) {
             data_temp.push_back(temp.data_float(i));
         }
-        tensor->copyData(data_temp);
+        tensor->copyin(data_temp);
     } else if (tensor->getDType() == DataType::UInt32) {
         std::vector<uint32_t> data_temp;
         for (int i = 0; i < temp.data_uint32_size(); ++i) {
             data_temp.push_back(temp.data_uint32(i));
         }
-        tensor->copyData(data_temp);
+        tensor->copyin(data_temp);
     } else {
         IT_TODO_HALT();
     }
@@ -15,17 +15,17 @@ TEST(Graph, build_and_run) {
     Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32);
     Tensor o0 = g->addTensor({1, 2, 4}, DataType::UInt32);
     g->dataMalloc();
-    i0->copyData(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
-    w0->copyData(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    i0->copyin(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    w0->copyin(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
     auto matmul = g->addOpWithOutputs<MatmulObj>(i0, w0, o0);
     g->print();
-    // check inputOf and outputsOf for tensor
-    EXPECT_EQ(i0->getInputOf().size(), 1u);
-    EXPECT_EQ(w0->getInputOf().size(), 1u);
-    EXPECT_EQ(o0->getInputOf().size(), 0u);
-    EXPECT_EQ(i0->getOutputOf(), nullptr);
-    EXPECT_EQ(w0->getOutputOf(), nullptr);
-    EXPECT_NE(o0->getOutputOf(), nullptr);
+    // check targets and source for tensor
+    EXPECT_EQ(i0->getTargets().size(), 1u);
+    EXPECT_EQ(w0->getTargets().size(), 1u);
+    EXPECT_EQ(o0->getTargets().size(), 0u);
+    EXPECT_EQ(i0->getSource(), nullptr);
+    EXPECT_EQ(w0->getSource(), nullptr);
+    EXPECT_NE(o0->getSource(), nullptr);
     EXPECT_EQ(matmul->getPredecessors().size(), 0u);
     EXPECT_EQ(matmul->getSuccessors().size(), 0u);
@@ -33,7 +33,7 @@ TEST(Graph, build_and_run) {
     // check execution results
     auto ans = make_ref<TensorObj>(Shape{1, 2, 4}, DataType::UInt32, runtime);
     ans->dataMalloc();
-    ans->copyData(vector<uint32_t>{38, 44, 50, 56, 83, 98, 113, 128});
+    ans->copyin(vector<uint32_t>{38, 44, 50, 56, 83, 98, 113, 128});
     EXPECT_TRUE(o0->equalData(ans));
 }
@@ -84,8 +84,8 @@ TEST(Graph, perf_engine) {
     auto matmul = g->addOp<MatmulObj>(i0, w0, nullptr);

     g->dataMalloc();
-    i0->copyData(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
-    w0->copyData(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    i0->copyin(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    w0->copyin(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
     runtime->run(g, true, true);
     double perfTime = runtime->getPerfTime(g);
     // The example matmul takes 0.0036ms with one core
@@ -94,7 +94,7 @@ TEST(Graph, perf_engine) {
     // check answer
     auto ans = make_ref<TensorObj>(Shape{1, 2, 4}, DataType::UInt32, runtime);
     ans->dataMalloc();
-    ans->copyData(vector<uint32_t>{38, 44, 50, 56, 83, 98, 113, 128});
+    ans->copyin(vector<uint32_t>{38, 44, 50, 56, 83, 98, 113, 128});
     EXPECT_TRUE(matmul->getOutput()->equalData(ans));
 }
@@ -105,8 +105,8 @@ TEST(Graph, test_tensor_id) {
     Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32);
     Tensor o0 = g->addTensor({1, 2, 4}, DataType::UInt32);
     g->dataMalloc();
-    i0->copyData(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
-    w0->copyData(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    i0->copyin(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    w0->copyin(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
     auto i1 = g->addTensor(i0->clone());
     auto matmul = g->addOpWithOutputs<MatmulObj>(i0, w0, o0);
     g->print();
@@ -123,8 +123,8 @@ TEST(Graph, test_OpVec_ctor) {
     Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32);
     Tensor o0 = g->addTensor({1, 2, 4}, DataType::UInt32);
     g->dataMalloc();
-    i0->copyData(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
-    w0->copyData(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    i0->copyin(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    w0->copyin(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
     auto o1 = g->addTensor(o0->clone());
     auto matmul = g->addOpWithOutputs<MatmulObj>(i0, w0, o0);
     g->addOp<ReluObj>(o1, nullptr);
@@ -139,8 +139,8 @@ TEST(Graph, test_OpVec_ctor) {
     map<pair<int, int>, int> inputOutput2Cnt = {
         {{1, 0}, 2}, {{1, 1}, 1}, {{0, 1}, 1}};
     for (auto t : g2->getTensors()) {
-        pair<int, int> key = {t->getInputOf().size(),
-                              t->getOutputOf() != nullptr};
+        pair<int, int> key = {t->getTargets().size(),
+                              t->getSource() != nullptr};
         EXPECT_GE(inputOutput2Cnt[key], 0);
         inputOutput2Cnt[key]--;
     }
@@ -19,11 +19,11 @@ namespace infini {
 //     Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32);
 //     Tensor o0 = g->addTensor({1, 2, 4}, DataType::UInt32);
 //     g->dataMalloc();
-//     i0->copyData(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
-//     w0->copyData(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+//     i0->copyin(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+//     w0->copyin(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
 //     auto matmul = g->addOpWithOutputs<MatmulObj>(i0, w0, o0);
 //     g->print();
-//     // check inputOf and outputsOf for tensor
+//     // check targets and source for tensor
 //     SearchEngine searchEngine(runtime, make_ref<NMutator>());
 //     searchEngine.run(g);
 //     // check execution results
@@ -46,7 +46,7 @@ TEST(Graph, search_withdm) {
     auto conv1 = g->addOpWithOutputs<ConvObj>(t3, w3, t4, 1, 1);
     auto add1 = g->addOpWithOutputs<AddObj>(t4, t5, t6);
     g->dataMalloc();
-    // check inputOf and outputsOf for tensor
+    // check targets and source for tensor
     SearchEngine searchEngine(runtime, make_ref<DummyMutator>(10));
     searchEngine.run(g);
     // check execution results
@@ -14,10 +14,10 @@ TEST(Prtotbuf, save_and_load) {
     Tensor u0 = g->addTensor({1, 3, 4}, DataType::UInt32);
     Tensor u1 = g->addTensor({1, 3, 4}, DataType::UInt32);
     g->dataMalloc();
-    i0->copyData(vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
-    w0->copyData(vector<float>{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
-    u0->copyData(vector<uint32_t>{1, 3, 5, 7, 9, 2, 4, 6, 8, 10, 0, 0});
-    u1->copyData(vector<uint32_t>{1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0});
+    i0->copyin(vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    w0->copyin(vector<float>{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
+    u0->copyin(vector<uint32_t>{1, 3, 5, 7, 9, 2, 4, 6, 8, 10, 0, 0});
+    u1->copyin(vector<uint32_t>{1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0});
     i0->save("i0.pb");
     w0->printData();
     w0->load("i0.pb");
@@ -22,8 +22,8 @@ TEST(CUDA_BatchNorm, run) {
     // Build input data on CPU
     gCpu->dataMalloc();
     iCpu->setData(IncrementalGenerator());
-    meanCpu->copyData(vector<float>{1, 6, 9});
-    varCpu->copyData(vector<float>{4, 1, 9});
+    meanCpu->copyin(vector<float>{1, 6, 9});
+    varCpu->copyin(vector<float>{4, 1, 9});
     scaleCpu->setData(OneGenerator());
     biasCpu->setData(ZeroGenerator());
@@ -181,8 +181,8 @@ TEST(Gather, Cuda) {
     auto input = gCpu->addTensor({3, 2}, DataType::Float32);
     auto index = gCpu->addTensor({2, 2}, DataType::UInt32);
     gCpu->dataMalloc();
-    input->copyData(vector<float>{1, 2, 3, 4, 5, 6});
-    index->copyData(vector<uint32_t>{0, 1, 1, 2});
+    input->copyin(vector<float>{1, 2, 3, 4, 5, 6});
+    index->copyin(vector<uint32_t>{0, 1, 1, 2});
     auto cudaRuntime = make_ref<CudaRuntimeObj>();
     Graph gCuda = make_ref<GraphObj>(cudaRuntime);
@@ -203,7 +203,7 @@ TEST(Gather, Cuda) {
     auto index = gCpu->addTensor({1, 2}, DataType::UInt32);
     gCpu->dataMalloc();
     input->setData(IncrementalGenerator());
-    index->copyData(vector<uint32_t>{0, 2});
+    index->copyin(vector<uint32_t>{0, 2});
     auto cudaRuntime = make_ref<CudaRuntimeObj>();
     Graph gCuda = make_ref<GraphObj>(cudaRuntime);
@@ -224,7 +224,7 @@ TEST(Gather, Cuda) {
     auto index = gCpu->addTensor({3, 1}, DataType::UInt32);
     gCpu->dataMalloc();
     input->setData(IncrementalGenerator());
-    index->copyData(vector<uint32_t>{0, 3, 1});
+    index->copyin(vector<uint32_t>{0, 3, 1});
     auto cudaRuntime = make_ref<CudaRuntimeObj>();
     Graph gCuda = make_ref<GraphObj>(cudaRuntime);
@@ -64,7 +64,7 @@ TEST(CUDA_Inception_v3_block, run) {

     // check connection
     EXPECT_EQ(maxpool->getSuccessors().size(), 4u);
-    EXPECT_EQ(chainInput->getInputOf().size(), 4u);
+    EXPECT_EQ(chainInput->getTargets().size(), 4u);
     for (const auto &chainOps : ops) {
         for (size_t i = 1; i < chainOps.size(); i++) {
             auto prev = chainOps[i - 1];
@@ -18,7 +18,7 @@ void test_reducemean(const Shape &shape, const vector<float> &data,
     // Build input data on CPU
     Tensor icpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
     icpu->dataMalloc();
-    icpu->copyData(data);
+    icpu->copyin(data);

     // Build CUDA graph
     Graph g = make_ref<GraphObj>(cudaRuntime);
@@ -13,8 +13,8 @@ TEST(Resize, Cuda_downsample_sizes_nearest) {
    auto input = gCpu->addTensor({1, 1, 2, 4}, DataType::Float32);
    auto sizes = gCpu->addTensor({4}, DataType::UInt32);
    gCpu->dataMalloc();
-    input->copyData(vector<float>{1, 2, 3, 4, 5, 6, 7, 8});
-    sizes->copyData(vector<uint32_t>{1, 1, 1, 3});
+    input->copyin(vector<float>{1, 2, 3, 4, 5, 6, 7, 8});
+    sizes->copyin(vector<uint32_t>{1, 1, 1, 3});

    auto cudaRuntime = make_ref<CudaRuntimeObj>();
    Graph gCuda = make_ref<GraphObj>(cudaRuntime);

@@ -38,8 +38,8 @@ TEST(Resize, Cuda_upsample_sizes_nearest_notlarger) {
    auto input = gCpu->addTensor({1, 1, 2, 2}, DataType::Float32);
    auto sizes = gCpu->addTensor({2}, DataType::UInt32);
    gCpu->dataMalloc();
-    input->copyData(vector<float>{1, 2, 3, 4});
-    sizes->copyData(vector<uint32_t>{7, 8});
+    input->copyin(vector<float>{1, 2, 3, 4});
+    sizes->copyin(vector<uint32_t>{7, 8});

    auto cudaRuntime = make_ref<CudaRuntimeObj>();
    Graph gCuda = make_ref<GraphObj>(cudaRuntime);

@@ -68,8 +68,8 @@ TEST(Resize, Cuda_upsample_sizes_nearest_notsmaller) {
    auto input = gCpu->addTensor({1, 1, 2, 2}, DataType::Float32);
    auto sizes = gCpu->addTensor({2}, DataType::UInt32);
    gCpu->dataMalloc();
-    input->copyData(vector<float>{1, 2, 3, 4});
-    sizes->copyData(vector<uint32_t>{7, 8});
+    input->copyin(vector<float>{1, 2, 3, 4});
+    sizes->copyin(vector<uint32_t>{7, 8});

    auto cudaRuntime = make_ref<CudaRuntimeObj>();
    Graph gCuda = make_ref<GraphObj>(cudaRuntime);

@@ -98,9 +98,9 @@ TEST(Resize, Cuda_upsample_sizes_nearest_ceil_half_pixel) {
    auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32);
    auto sizes = gCpu->addTensor({4}, DataType::UInt32);
    gCpu->dataMalloc();
-    input->copyData(
+    input->copyin(
        vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
-    sizes->copyData(vector<uint32_t>{1, 1, 8, 8});
+    sizes->copyin(vector<uint32_t>{1, 1, 8, 8});

    auto cudaRuntime = make_ref<CudaRuntimeObj>();
    Graph gCuda = make_ref<GraphObj>(cudaRuntime);

@@ -131,9 +131,9 @@ TEST(Resize, Cuda_upsample_sizes_nearest_floor_align_corners) {
    auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32);
    auto sizes = gCpu->addTensor({2}, DataType::UInt32);
    gCpu->dataMalloc();
-    input->copyData(
+    input->copyin(
        vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
-    sizes->copyData(vector<uint32_t>{8, 8});
+    sizes->copyin(vector<uint32_t>{8, 8});

    auto cudaRuntime = make_ref<CudaRuntimeObj>();
    Graph gCuda = make_ref<GraphObj>(cudaRuntime);

@@ -164,9 +164,9 @@ TEST(Resize, Cuda_upsample_sizes_nearest_round_prefer_ceil_asymmetri) {
    auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32);
    auto sizes = gCpu->addTensor({4}, DataType::UInt32);
    gCpu->dataMalloc();
-    input->copyData(
+    input->copyin(
        vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
-    sizes->copyData(vector<uint32_t>{1, 1, 8, 8});
+    sizes->copyin(vector<uint32_t>{1, 1, 8, 8});

    auto cudaRuntime = make_ref<CudaRuntimeObj>();
    Graph gCuda = make_ref<GraphObj>(cudaRuntime);

@@ -197,8 +197,8 @@ TEST(Resize, Cuda_downsample_scales_nearest) {
    auto input = gCpu->addTensor({1, 1, 2, 4}, DataType::Float32);
    auto scales = gCpu->addTensor({4}, DataType::Float32);
    gCpu->dataMalloc();
-    input->copyData(vector<float>{1, 2, 3, 4, 5, 6, 7, 8});
-    scales->copyData(vector<float>{1, 1, 0.6, 0.6});
+    input->copyin(vector<float>{1, 2, 3, 4, 5, 6, 7, 8});
+    scales->copyin(vector<float>{1, 1, 0.6, 0.6});

    auto cudaRuntime = make_ref<CudaRuntimeObj>();
    Graph gCuda = make_ref<GraphObj>(cudaRuntime);

@@ -221,8 +221,8 @@ TEST(Resize, Cuda_upsample_scales_nearest) {
    auto input = gCpu->addTensor({1, 1, 2, 2}, DataType::Float32);
    auto scales = gCpu->addTensor({4}, DataType::Float32);
    gCpu->dataMalloc();
-    input->copyData(vector<float>{1, 2, 3, 4});
-    scales->copyData(vector<float>{1, 1, 2, 3});
+    input->copyin(vector<float>{1, 2, 3, 4});
+    scales->copyin(vector<float>{1, 1, 2, 3});

    auto cudaRuntime = make_ref<CudaRuntimeObj>();
    Graph gCuda = make_ref<GraphObj>(cudaRuntime);

@@ -247,8 +247,8 @@ TEST(Resize, Cuda_upsample_scales_nearest_axes_3_2) {
    auto input = gCpu->addTensor({1, 1, 2, 2}, DataType::Float32);
    auto scales = gCpu->addTensor({2}, DataType::Float32);
    gCpu->dataMalloc();
-    input->copyData(vector<float>{1, 2, 3, 4});
-    scales->copyData(vector<float>{3, 2});
+    input->copyin(vector<float>{1, 2, 3, 4});
+    scales->copyin(vector<float>{3, 2});

    auto cudaRuntime = make_ref<CudaRuntimeObj>();
    Graph gCuda = make_ref<GraphObj>(cudaRuntime);

@@ -273,8 +273,8 @@ TEST(Resize, Cuda_downsample_scales_linear) {
    auto input = gCpu->addTensor({1, 1, 2, 4}, DataType::Float32);
    auto scales = gCpu->addTensor({4}, DataType::Float32);
    gCpu->dataMalloc();
-    input->copyData(vector<float>{1, 2, 3, 4, 5, 6, 7, 8});
-    scales->copyData(vector<float>{1, 1, 0.6, 0.6});
+    input->copyin(vector<float>{1, 2, 3, 4, 5, 6, 7, 8});
+    scales->copyin(vector<float>{1, 1, 0.6, 0.6});

    auto cudaRuntime = make_ref<CudaRuntimeObj>();
    Graph gCuda = make_ref<GraphObj>(cudaRuntime);

@@ -297,8 +297,8 @@ TEST(Resize, Cuda_downsample_scales_linear_aligncorners) {
    auto input = gCpu->addTensor({1, 1, 2, 4}, DataType::Float32);
    auto scales = gCpu->addTensor({4}, DataType::Float32);
    gCpu->dataMalloc();
-    input->copyData(vector<float>{1, 2, 3, 4, 5, 6, 7, 8});
-    scales->copyData(vector<float>{1, 1, 0.6, 0.6});
+    input->copyin(vector<float>{1, 2, 3, 4, 5, 6, 7, 8});
+    scales->copyin(vector<float>{1, 1, 0.6, 0.6});

    auto cudaRuntime = make_ref<CudaRuntimeObj>();
    Graph gCuda = make_ref<GraphObj>(cudaRuntime);

@@ -323,8 +323,8 @@ TEST(Resize, Cuda_upsample_scales_linear) {
    auto input = gCpu->addTensor({1, 1, 2, 2}, DataType::Float32);
    auto scales = gCpu->addTensor({4}, DataType::Float32);
    gCpu->dataMalloc();
-    input->copyData(vector<float>{1, 2, 3, 4});
-    scales->copyData(vector<float>{1, 1, 2, 2});
+    input->copyin(vector<float>{1, 2, 3, 4});
+    scales->copyin(vector<float>{1, 1, 2, 2});

    auto cudaRuntime = make_ref<CudaRuntimeObj>();
    Graph gCuda = make_ref<GraphObj>(cudaRuntime);

@@ -349,8 +349,8 @@ TEST(Resize, Cuda_upsample_scales_linear_align_corners) {
    auto input = gCpu->addTensor({1, 1, 2, 2}, DataType::Float32);
    auto scales = gCpu->addTensor({4}, DataType::Float32);
    gCpu->dataMalloc();
-    input->copyData(vector<float>{1, 2, 3, 4});
-    scales->copyData(vector<float>{1, 1, 2, 2});
+    input->copyin(vector<float>{1, 2, 3, 4});
+    scales->copyin(vector<float>{1, 1, 2, 2});

    auto cudaRuntime = make_ref<CudaRuntimeObj>();
    Graph gCuda = make_ref<GraphObj>(cudaRuntime);

@@ -377,9 +377,9 @@ TEST(Resize, Cuda_downsample_sizes_linear_pytorchhalfpixel) {
    auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32);
    auto sizes = gCpu->addTensor({4}, DataType::UInt32);
    gCpu->dataMalloc();
-    input->copyData(
+    input->copyin(
        vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
-    sizes->copyData(vector<uint32_t>{1, 1, 3, 1});
+    sizes->copyin(vector<uint32_t>{1, 1, 3, 1});

    auto cudaRuntime = make_ref<CudaRuntimeObj>();
    Graph gCuda = make_ref<GraphObj>(cudaRuntime);

@@ -406,10 +406,10 @@ TEST(Resize, Cuda_tf_crop_and_resize) {
    auto sizes = gCpu->addTensor({4}, DataType::UInt32);
    auto roi = gCpu->addTensor({8}, DataType::Float32);
    gCpu->dataMalloc();
-    input->copyData(
+    input->copyin(
        vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
-    sizes->copyData(vector<uint32_t>{1, 1, 3, 3});
-    roi->copyData(vector<float>{0, 0, 0.4, 0.6, 1, 1, 0.6, 0.8});
+    sizes->copyin(vector<uint32_t>{1, 1, 3, 3});
+    roi->copyin(vector<float>{0, 0, 0.4, 0.6, 1, 1, 0.6, 0.8});

    auto cudaRuntime = make_ref<CudaRuntimeObj>();
    Graph gCuda = make_ref<GraphObj>(cudaRuntime);

@@ -437,10 +437,10 @@ TEST(Resize, Cuda_tf_crop_and_resize_axes_3_2) {
    auto sizes = gCpu->addTensor({2}, DataType::UInt32);
    auto roi = gCpu->addTensor({4}, DataType::Float32);
    gCpu->dataMalloc();
-    input->copyData(
+    input->copyin(
        vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
-    sizes->copyData(vector<uint32_t>{3, 3});
-    roi->copyData(vector<float>{0.6, 0.4, 0.8, 0.6});
+    sizes->copyin(vector<uint32_t>{3, 3});
+    roi->copyin(vector<float>{0.6, 0.4, 0.8, 0.6});

    auto cudaRuntime = make_ref<CudaRuntimeObj>();
    Graph gCuda = make_ref<GraphObj>(cudaRuntime);

@@ -467,9 +467,9 @@ TEST(Resize, Cuda_downsample_scales_cubic) {
    auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32);
    auto scales = gCpu->addTensor({4}, DataType::Float32);
    gCpu->dataMalloc();
-    input->copyData(
+    input->copyin(
        vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
-    scales->copyData(vector<float>{1.0, 1.0, 0.8, 0.8});
+    scales->copyin(vector<float>{1.0, 1.0, 0.8, 0.8});

    auto cudaRuntime = make_ref<CudaRuntimeObj>();
    Graph gCuda = make_ref<GraphObj>(cudaRuntime);

@@ -494,9 +494,9 @@ TEST(Resize, Cuda_downsample_scales_cubic_align_corners) {
    auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32);
    auto scales = gCpu->addTensor({4}, DataType::Float32);
    gCpu->dataMalloc();
-    input->copyData(
+    input->copyin(
        vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
-    scales->copyData(vector<float>{1.0, 1.0, 0.8, 0.8});
+    scales->copyin(vector<float>{1.0, 1.0, 0.8, 0.8});

    auto cudaRuntime = make_ref<CudaRuntimeObj>();
    Graph gCuda = make_ref<GraphObj>(cudaRuntime);

@@ -522,9 +522,9 @@ TEST(Resize, Cuda_upsample_scales_cubic) {
    auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32);
    auto scales = gCpu->addTensor({4}, DataType::Float32);
    gCpu->dataMalloc();
-    input->copyData(
+    input->copyin(
        vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
-    scales->copyData(vector<float>{1.0, 1.0, 2, 2});
+    scales->copyin(vector<float>{1.0, 1.0, 2, 2});

    auto cudaRuntime = make_ref<CudaRuntimeObj>();
    Graph gCuda = make_ref<GraphObj>(cudaRuntime);

@@ -559,9 +559,9 @@ TEST(Resize, Cuda_upsample_scales_cubic_align_corners) {
    auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32);
    auto scales = gCpu->addTensor({4}, DataType::Float32);
    gCpu->dataMalloc();
-    input->copyData(
+    input->copyin(
        vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
-    scales->copyData(vector<float>{1.0, 1.0, 2, 2});
+    scales->copyin(vector<float>{1.0, 1.0, 2, 2});

    auto cudaRuntime = make_ref<CudaRuntimeObj>();
    Graph gCuda = make_ref<GraphObj>(cudaRuntime);

@@ -598,9 +598,9 @@ TEST(Resize, Cuda_upsample_scales_cubic_asymmetric) {
    auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32);
    auto scales = gCpu->addTensor({4}, DataType::Float32);
    gCpu->dataMalloc();
-    input->copyData(
+    input->copyin(
        vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
-    scales->copyData(vector<float>{1.0, 1.0, 2, 2});
+    scales->copyin(vector<float>{1.0, 1.0, 2, 2});

    auto cudaRuntime = make_ref<CudaRuntimeObj>();
    Graph gCuda = make_ref<GraphObj>(cudaRuntime);

@@ -633,9 +633,9 @@ TEST(Resize, Cuda_downsample_sizes_cubic) {
    auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32);
    auto sizes = gCpu->addTensor({4}, DataType::UInt32);
    gCpu->dataMalloc();
-    input->copyData(
+    input->copyin(
        vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
-    sizes->copyData(vector<uint32_t>{1, 1, 3, 3});
+    sizes->copyin(vector<uint32_t>{1, 1, 3, 3});

    auto cudaRuntime = make_ref<CudaRuntimeObj>();
    Graph gCuda = make_ref<GraphObj>(cudaRuntime);

@@ -667,9 +667,9 @@ TEST(Resize, Cuda_upsample_sizes_cubic) {
    auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32);
    auto sizes = gCpu->addTensor({4}, DataType::UInt32);
    gCpu->dataMalloc();
-    input->copyData(
+    input->copyin(
        vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
-    sizes->copyData(vector<uint32_t>{1, 1, 9, 10});
+    sizes->copyin(vector<uint32_t>{1, 1, 9, 10});

    auto cudaRuntime = make_ref<CudaRuntimeObj>();
    Graph gCuda = make_ref<GraphObj>(cudaRuntime);
@@ -9,10 +9,10 @@ TEST(BatchNorm, ShapeInference) {
    {
        Graph g = make_ref<GraphObj>(cpuRuntime);
        Tensor i = g->addTensor({1, 3, 2, 2}, DataType::UInt32);
-        Tensor mean = g->addTensor({1, 3, 1, 1}, DataType::Float32);
-        Tensor var = g->addTensor({1, 3, 1, 1}, DataType::Float32);
-        Tensor scaler = g->addTensor({1, 3, 1, 1}, DataType::Float32);
-        Tensor bias = g->addTensor({1, 3, 1, 1}, DataType::Float32);
+        Tensor mean = g->addTensor({3}, DataType::Float32);
+        Tensor var = g->addTensor({3}, DataType::Float32);
+        Tensor scaler = g->addTensor({3}, DataType::Float32);
+        Tensor bias = g->addTensor({3}, DataType::Float32);
        auto op = g->addOp<BatchNormObj>(i, nullptr, mean, var, scaler, bias,
                                         0.9, 1e-5);
        EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 3, 2, 2}));
@@ -61,7 +61,7 @@ TEST(Conv, NaiveCPU) {
    auto ans =
        make_ref<TensorObj>(Shape{1, 2, 2, 2}, DataType::UInt32, runtime);
    ans->dataMalloc();
-    ans->copyData(
+    ans->copyin(
        vector<uint32_t>{4794, 4386, 8199, 7506, 11274, 10542, 20835, 19656});
    EXPECT_TRUE(conv->getOutput()->equalData(ans));
 }
@@ -12,7 +12,7 @@ TEST(Resize, ShapeInference) {
    Tensor i = g->addTensor({1, 1, 2, 4}, DataType::UInt32);
    Tensor sizes = g->addTensor({4}, DataType::UInt32);
    sizes->dataMalloc();
-    sizes->copyData(vector<uint32_t>{1, 1, 1, 3});
+    sizes->copyin(vector<uint32_t>{1, 1, 1, 3});
    auto op = g->addOp<ResizeObj>(
        i, nullptr, std::nullopt, sizes, nullptr, nullptr,
        ResizeObj::EKeepAspectRatioPolicy::stretch);

@@ -24,7 +24,7 @@ TEST(Resize, ShapeInference) {
    Tensor i = g->addTensor({1, 1, 2, 4}, DataType::UInt32);
    Tensor sizes = g->addTensor({2}, DataType::UInt32);
    sizes->dataMalloc();
-    sizes->copyData(vector<uint32_t>{1, 3});
+    sizes->copyin(vector<uint32_t>{1, 3});
    auto op = g->addOp<ResizeObj>(
        i, nullptr, vector<int>{2, 3}, sizes, nullptr, nullptr,
        ResizeObj::EKeepAspectRatioPolicy::stretch);

@@ -36,7 +36,7 @@ TEST(Resize, ShapeInference) {
    Tensor i = g->addTensor({1, 3, 2, 4}, DataType::UInt32);
    Tensor sizes = g->addTensor({2}, DataType::UInt32);
    sizes->dataMalloc();
-    sizes->copyData(vector<uint32_t>{7, 8});
+    sizes->copyin(vector<uint32_t>{7, 8});
    auto op = g->addOp<ResizeObj>(
        i, nullptr, vector<int>{2, 3}, sizes, nullptr, nullptr,
        ResizeObj::EKeepAspectRatioPolicy::notLarger);

@@ -48,7 +48,7 @@ TEST(Resize, ShapeInference) {
    Tensor i = g->addTensor({1, 3, 2, 4}, DataType::UInt32);
    Tensor sizes = g->addTensor({3}, DataType::UInt32);
    sizes->dataMalloc();
-    sizes->copyData(vector<uint32_t>{2, 6, 8});
+    sizes->copyin(vector<uint32_t>{2, 6, 8});
    auto op = g->addOp<ResizeObj>(
        i, nullptr, vector<int>{1, 2, 3}, sizes, nullptr, nullptr,
        ResizeObj::EKeepAspectRatioPolicy::notSmaller);

@@ -60,7 +60,7 @@ TEST(Resize, ShapeInference) {
    Tensor i = g->addTensor({1, 1, 4, 4}, DataType::UInt32);
    Tensor scales = g->addTensor({3}, DataType::Float32);
    scales->dataMalloc();
-    scales->copyData(vector<float>{1, 0.8, 0.8});
+    scales->copyin(vector<float>{1, 0.8, 0.8});
    auto op = g->addOp<ResizeObj>(i, nullptr, vector<int>{1, 2, 3}, nullptr,
                                  scales, nullptr);
    EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 1, 3, 3}));

@@ -71,7 +71,7 @@ TEST(Resize, ShapeInference) {
    Tensor i = g->addTensor({1, 1, 2, 2}, DataType::UInt32);
    Tensor scales = g->addTensor({4}, DataType::Float32);
    scales->dataMalloc();
-    scales->copyData(vector<float>{1, 1, 2, 2});
+    scales->copyin(vector<float>{1, 1, 2, 2});
    auto op = g->addOp<ResizeObj>(i, nullptr, std::nullopt, nullptr, scales,
                                  nullptr);
    EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 1, 4, 4}));