forked from jiuyuan/InfiniTensor
feat: support new graph to old runtime
parent e637f9e7dd
commit a8f8d504f4
@@ -2,6 +2,7 @@
#include "core/lazy_allocator.h"
#include "core/operator.h"
#include "core/tensor.h"
#include "computation/graph.h"

namespace infini {
@@ -113,6 +114,8 @@ class GraphObj : public Object {
bool checkValid() const;

void transformFromGraphTopo(refactor::computation::Graph &graph, Runtime runtime);

private:
/**
* @brief Add reverse connections and Op relationship in ctor.
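
The new transformFromGraphTopo entry point lets a GraphObj rebuild itself from a refactor::computation::Graph and then execute on an existing runtime. A minimal usage sketch, assuming a CUDA build and mirroring the runCuda flow added further down in this commit (not part of the diff itself):

    #include "computation/graph.h"
    #include "core/graph.h"
    #include "cuda/cuda_runtime.h"

    // Sketch only: rebuild the old-style graph from the new graph IR and run it.
    void runOnOldRuntime(refactor::computation::Graph &newGraph) {
        auto runtime = infini::make_ref<infini::CudaRuntimeObj>(0); // device 0
        auto graph = infini::make_ref<infini::GraphObj>(runtime);
        graph->transformFromGraphTopo(newGraph, runtime); // new graph -> old ops/tensors
        graph->dataMalloc();                              // allocate tensor memory
        graph->getRuntime()->run(graph);                  // run with the existing kernels
    }
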
@@ -1,8 +1,10 @@
#pragma once

namespace infini {
void div_kernel(float *a, float *b, float *c, int a0, int a1, int a2, int a3,
void div_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3,
int b0, int b1, int b2, int b3, int c0, int c1, int c2, int c3);
void pow_kernel(float *a, float *b, float *c, int a0, int a1, int a2, int a3,
void add_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3,
int b0, int b1, int b2, int b3, int c0, int c1, int c2, int c3);
void pow_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3,
int b0, int b1, int b2, int b3, int c0, int c1, int c2, int c3);
}; // namespace infini
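
Switching these launchers from float * to void * keeps a single C-style entry point per operator while the element type is decided elsewhere (here, inside the .cu host functions). A self-contained sketch of the general pattern, not InfiniTensor's actual API: a type-erased front end plus a templated implementation chosen from a dtype tag.

    #include <cstdint>

    enum class DType { F32, I64 };

    template <typename T>
    static void add_impl(const T *a, const T *b, T *c, int n) {
        for (int i = 0; i < n; ++i)
            c[i] = a[i] + b[i];
    }

    // Type-erased front end: callers pass raw buffers plus a dtype tag,
    // and the template instantiation is picked here.
    void add_kernel(const void *a, const void *b, void *c, int n, DType dt) {
        switch (dt) {
        case DType::F32:
            add_impl(static_cast<const float *>(a), static_cast<const float *>(b),
                     static_cast<float *>(c), n);
            break;
        case DType::I64:
            add_impl(static_cast<const int64_t *>(a), static_cast<const int64_t *>(b),
                     static_cast<int64_t *>(c), n);
            break;
        }
    }
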
@@ -10,10 +10,11 @@ typedef struct {
int wholeNDim[MAX_DIM];  // dim size after padding or before slicing
int partNDim[MAX_DIM];   // dim size before padding or after slicing
int partStride[MAX_DIM]; // stride before padding or after slicing
int DType;
} TransMetaData;

namespace infini {
void pad_slice_kernel(float *partData, float *wholeData,
void pad_slice_kernel(void *partData, void *wholeData,
const TransMetaData &metadata, int nDims, int num,
bool isPad);
} // namespace infini
@@ -3,6 +3,8 @@
#define OPERATOR_UTIL_H

#include "core/tensor.h"
#include "core/graph.h"
#include "computation/graph.h"

namespace infini {
@@ -10,6 +12,18 @@ namespace infini {
Shape infer_broadcast(const Shape &A, const Shape &B);
// Launch the real axis based on rank and current axis
int get_real_axis(const int &axis, const int &rank);

// transform RefactorGraph node to InfiniTensorGraph operator
void addOperatorFromGraphTopo(GraphObj &g,
std::shared_ptr<refactor::computation::Operator> nodeInfo,
std::vector<size_t> input, std::vector<size_t> output,
std::unordered_map<size_t, Tensor> &edgeToTensor,
std::vector<refactor::computation::Edge> edges);

void addEdgeToTensor(GraphObj &g, size_t index,
std::shared_ptr<refactor::computation::Tensor> tensor,
std::unordered_map<size_t, Tensor> &edgeToTensor,
Runtime runtime);
} // namespace infini

#endif
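
addOperatorFromGraphTopo and addEdgeToTensor share an edgeToTensor map keyed by edge index, so each edge of the source graph is materialized as an InfiniTensor Tensor exactly once and then reused by every node that touches it. A generic sketch of that memoization pattern with simplified stand-in types (not the real signatures):

    #include <memory>
    #include <unordered_map>
    #include <vector>

    struct FakeTensor { std::vector<int> shape; };   // stand-in for infini::Tensor
    using TensorPtr = std::shared_ptr<FakeTensor>;

    // Return the tensor for an edge index, creating it on first use.
    TensorPtr getOrCreate(std::unordered_map<size_t, TensorPtr> &edgeToTensor,
                          size_t edge, const std::vector<int> &shape) {
        auto it = edgeToTensor.find(edge);
        if (it != edgeToTensor.end())
            return it->second;                        // already materialized
        auto t = std::make_shared<FakeTensor>(FakeTensor{shape});
        edgeToTensor.emplace(edge, t);
        return t;
    }
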
@@ -1,5 +1,6 @@
#include "core/graph.h"
#include "operators/reshape.h"
#include "utils/operator_utils.h"
#include <algorithm>
#include <numeric>
#include <queue>
@@ -349,4 +350,37 @@ bool GraphObj::checkValid() const {
return true;
}

void GraphObj::transformFromGraphTopo(refactor::computation::Graph &graph, Runtime runtime) {
// create ops and tensors
ops.clear();
tensors.clear();
auto const& nodes = graph.internal().nodes;
auto const& edges = graph.internal().edges;
std::unordered_map<size_t, Tensor> edgeToTensor;

for (auto [nodeIdx, inputs, outputs] : graph.internal().topology) {
// not dynamic_node
if (!std::all_of(outputs.begin(), outputs.end(), [&](auto e) { return edges[e].tensor->hasData(); })) {
auto nodeInfo = nodes[nodeIdx];
IT_ASSERT(refactor::computation::OpType::tryParse(nodeInfo.op->opType.name().data()));
std::vector<size_t> in, out;
for (auto i : inputs) {
if (edgeToTensor.find(i) == edgeToTensor.end()) {
addEdgeToTensor(*this, i, edges[i].tensor, edgeToTensor, runtime);
}
in.emplace_back(i);
}
for (auto i : outputs) {
if (edgeToTensor.find(i) == edgeToTensor.end()) {
addEdgeToTensor(*this, i, edges[i].tensor, edgeToTensor, runtime);
}
out.emplace_back(i);
}
IT_ASSERT(out.size() == outputs.size());
IT_ASSERT(in.size() == inputs.size());
addOperatorFromGraphTopo(*this, nodeInfo.op, in, out, edgeToTensor, edges);
}
}
}

} // namespace infini
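
The std::all_of check above skips nodes whose output edges all already carry data: those values are effectively constants, so no runtime operator has to be created for them. A minimal illustration of the same predicate on plain data:

    #include <algorithm>
    #include <cassert>
    #include <vector>

    struct Edge { bool hasData; };

    // A node can be skipped when every one of its output edges already holds data.
    bool isStaticNode(const std::vector<Edge> &edges,
                      const std::vector<size_t> &outputs) {
        return std::all_of(outputs.begin(), outputs.end(),
                           [&](size_t e) { return edges[e].hasData; });
    }

    int main() {
        std::vector<Edge> edges{{true}, {false}, {true}};
        assert(isStaticNode(edges, {0, 2}));    // both outputs constant -> skip
        assert(!isStaticNode(edges, {1, 2}));   // edge 1 still needs computing
    }
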
@@ -57,11 +57,11 @@ HashType OperatorObj::hash() const {
bool OperatorObj::checkValid(GraphObj *graph) {
auto optShapes = inferShape();
if (!optShapes) // shape inference failed
return false;
IT_ASSERT(false);

const vector<Shape> &shapes = *optShapes;
if (shapes.size() != outputs.size())
return false;
IT_ASSERT(false);
if (graph) { // if graph != nullptr, outputs should be created
auto dataTypes = inferDataType();
for (size_t i = 0; i < outputs.size(); i++) {
@@ -70,8 +70,11 @@ bool OperatorObj::checkValid(GraphObj *graph) {
}
} else { // if outputs have been created, check their shapes
for (size_t i = 0; i < shapes.size(); ++i) {
if (shapes[i] != outputs[i]->getDims())
return false;
if (shapes[i] != outputs[i]->getDims()) {
std::cout<<"shapes"<<vecToString(shapes[i])<<std::endl;
std::cout<<vecToString(outputs[i]->getDims())<<std::endl;
IT_ASSERT(false);
}
}
}
return true;
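
checkValid now prints the mismatching shapes and asserts instead of quietly returning false, so a bad shape inference stops the run at the point of failure. IT_ASSERT's definition is not part of this diff; a minimal stand-in with the same fail-fast behavior might look like:

    #include <cstdio>
    #include <cstdlib>

    // Hypothetical stand-in: report the failed condition and its location, then abort.
    #define MY_ASSERT(cond)                                                       \
        do {                                                                      \
            if (!(cond)) {                                                        \
                std::fprintf(stderr, "Assertion failed: %s (%s:%d)\n", #cond,     \
                             __FILE__, __LINE__);                                 \
                std::abort();                                                     \
            }                                                                     \
        } while (0)
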
@@ -1,10 +1,22 @@
#include "common/error_handler.h"
#include "communication/operators.h"
#include "core/graph.h"
#include "computation/graph.h"
#include "onnx/operators.h"
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#ifdef USE_CUDA
#include "cuda/cuda_runtime.h"
#include "cuda/operator_timer.h"
#endif
#ifdef USE_BANG
#include "bang/bang_runtime.h"
#endif
#ifdef USE_INTELCPU
#include "intelcpu/mkl_runtime.h"
#include "intelcpu/operator_timer.h"
#endif

namespace py = pybind11;
@@ -28,7 +40,17 @@ class Handler {
fmt::format("Variable {} not exist", name));
}
auto const &graph() const { return _g.internal(); }
void runCuda() { TODO("Not implemented"); }
#ifdef USE_CUDA
void runCuda() {
using namespace infini;
auto cudaRuntime = make_ref<CudaRuntimeObj>(0);
auto graph = make_ref<GraphObj>(std::move(cudaRuntime));
graph->transformFromGraphTopo(_g, cudaRuntime);
//graph->print();
graph->dataMalloc();
graph->getRuntime()->run(graph);
}
#endif
};

using TExport = std::tuple<Name, int, std::vector<std::variant<Name, int>>>;
@@ -196,8 +218,10 @@ void register_refactor(py::module &m) {
py::class_<Handler, std::shared_ptr<Handler>>(m, "Graph")
.def("fill_edge_info", &Handler::fillEdgeInfo)
.def("substitute", &Handler::substitute)
.def("set_input", &Handler::setInput)
.def("run_cuda", &Handler::runCuda);
#ifdef USE_CUDA
.def("run_cuda", &Handler::runCuda)
#endif
.def("set_input", &Handler::setInput);
py::class_<NodeExport>(m, "NodeExport")
.def(py::init<std::shared_ptr<Handler>>())
.def("global_inputs", &NodeExport::globalInputs)
@@ -212,4 +236,6 @@ void register_refactor(py::module &m) {
}
} // namespace

PYBIND11_MODULE(backend, m) { register_refactor(m); }
PYBIND11_MODULE(backend, m) {
register_refactor(m);
}
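
run_cuda is now bound only when the module is built with CUDA support, by splitting the .def chain around an #ifdef; preprocessor lines between chained calls are legal, so the expression still compiles either way. A self-contained pybind11 sketch of the same pattern (class and module names here are illustrative):

    #include <pybind11/pybind11.h>
    namespace py = pybind11;

    struct Handler {
        void setInput() {}
    #ifdef USE_CUDA
        void runCuda() {}
    #endif
    };

    PYBIND11_MODULE(example, m) {
        py::class_<Handler, std::shared_ptr<Handler>>(m, "Graph")
            .def(py::init<>())
    #ifdef USE_CUDA
            .def("run_cuda", &Handler::runCuda)   // only present in CUDA builds
    #endif
            .def("set_input", &Handler::setInput);
    }
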
@@ -44,7 +44,6 @@ class ElementWiseCudnn : public CudaKernelWithoutConfig {
std::copy(a_dim.begin(), a_dim.end(), a + (4 - a_dim.size()));
std::copy(b_dim.begin(), b_dim.end(), b + (4 - b_dim.size()));
std::copy(c_dim.begin(), c_dim.end(), c + (4 - c_dim.size()));

// get inputs
checkCudnnError(cudnnCreateTensorDescriptor(&aDesc));
checkCudnnError(cudnnSetTensor4dDescriptor(aDesc, CUDNN_TENSOR_NCHW,
@@ -110,9 +109,9 @@ class ElementWiseCuda : public CudaKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<ElementWiseObj>(_op);
float *const aData = (op->getInputs(0)->getRawDataPtr<float *>());
float *const bData = (op->getInputs(1)->getRawDataPtr<float *>());
float *const cData = (op->getOutput()->getRawDataPtr<float *>());
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
auto a_dim = op->getInputs(0)->getDims();
auto b_dim = op->getInputs(1)->getDims();
auto c_dim = op->getOutput()->getDims();
@@ -134,6 +133,10 @@ class ElementWiseCuda : public CudaKernelWithoutConfig {
else if (op->getOpType() == OpType::Pow)
pow_kernel(aData, bData, cData, a[0], a[1], a[2], a[3], b[0], b[1],
b[2], b[3], c[0], c[1], c[2], c[3]);
else if (op->getOpType() == OpType::Add) {
add_kernel(aData, bData, cData, a[0], a[1], a[2], a[3], b[0], b[1],
b[2], b[3], c[0], c[1], c[2], c[3]);
}
else
IT_TODO_HALT();
}
@@ -152,6 +155,8 @@ REGISTER_KERNEL(Device::CUDA, OpType::Max, DataType::Float32, MaxCudnn,

REGISTER_KERNEL(Device::CUDA, OpType::Div, DataType::Float32, ElementWiseCuda,
"Div_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Add, DataType::Int64, ElementWiseCuda,
"Add_CUDA_Int64");
REGISTER_KERNEL(Device::CUDA, OpType::Pow, DataType::Float32, ElementWiseCuda,
"Pow__CUDA_Float32");
}; // namespace infini
@@ -5,7 +5,7 @@ constexpr unsigned int num_threads() { return 32 * 4; }
constexpr int thread_work_size() { return 4; }
constexpr int block_work_size() { return thread_work_size() * num_threads(); }

__global__ void _div_kernel(float *x, float *y, float *z, int a0, int a1,
__global__ void _div_kernel(void *x, void *y, void *z, int a0, int a1,
int a2, int a3, int b0, int b1, int b2, int b3,
int c0, int c1, int c2, int c3) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
@@ -27,14 +27,15 @@ __global__ void _div_kernel(float *x, float *y, float *z, int a0, int a1,
int b1_index = c1_index % b1;
int b2_index = c2_index % b2;
int b3_index = c3_index % b3;
z[i] = x[a0_index * a1 * a2 * a3 + a1_index * a2 * a3 + a2_index * a3 +
((float *)z)[i] = ((float *)x)[a0_index * a1 * a2 * a3 + a1_index * a2 * a3 + a2_index * a3 +
a3_index] /
y[b0_index * b1 * b2 * b3 + b1_index * b2 * b3 + b2_index * b3 +
((float *)y)[b0_index * b1 * b2 * b3 + b1_index * b2 * b3 + b2_index * b3 +
b3_index];
}
}

__global__ void _pow_kernel(float *x, float *y, float *z, int a0, int a1,
template <class T>
__global__ void _add_kernel(void *x, void *y, void *z, int a0, int a1,
int a2, int a3, int b0, int b1, int b2, int b3,
int c0, int c1, int c2, int c3) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
@@ -56,15 +57,43 @@ __global__ void _pow_kernel(float *x, float *y, float *z, int a0, int a1,
int b1_index = c1_index % b1;
int b2_index = c2_index % b2;
int b3_index = c3_index % b3;
z[i] = pow(x[a0_index * a1 * a2 * a3 + a1_index * a2 * a3 +
((T *)z)[i] = ((T *)x)[a0_index * a1 * a2 * a3 + a1_index * a2 * a3 + a2_index * a3 +
a3_index] +
((T *)y)[b0_index * b1 * b2 * b3 + b1_index * b2 * b3 + b2_index * b3 +
b3_index];
}
}
__global__ void _pow_kernel(void *x, void *y, void *z, int a0, int a1,
int a2, int a3, int b0, int b1, int b2, int b3,
int c0, int c1, int c2, int c3) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
int n = c0 * c1 * c2 * c3;

for (int i = index; i < n; i += stride) {
int c0_index = i / (c1 * c2 * c3);
int c1_index = (i % (c1 * c2 * c3)) / (c2 * c3);
int c2_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) / c3;
int c3_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) % c3;

int a0_index = c0_index % a0;
int a1_index = c1_index % a1;
int a2_index = c2_index % a2;
int a3_index = c3_index % a3;

int b0_index = c0_index % b0;
int b1_index = c1_index % b1;
int b2_index = c2_index % b2;
int b3_index = c3_index % b3;
((float *)z)[i] = pow(((float *)x)[a0_index * a1 * a2 * a3 + a1_index * a2 * a3 +
a2_index * a3 + a3_index],
y[b0_index * b1 * b2 * b3 + b1_index * b2 * b3 +
((float *)y)[b0_index * b1 * b2 * b3 + b1_index * b2 * b3 +
b2_index * b3 + b3_index]);
}
}

namespace infini {
void div_kernel(float *a, float *b, float *c, int a0, int a1, int a2, int a3,
void div_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3,
int b0, int b1, int b2, int b3, int c0, int c1, int c2,
int c3) {
@@ -74,7 +103,17 @@ void div_kernel(float *a, float *b, float *c, int a0, int a1, int a2, int a3,
_div_kernel<<<gridsize, blocksize>>>(a, b, c, a0, a1, a2, a3, b0, b1, b2,
b3, c0, c1, c2, c3);
}
void pow_kernel(float *a, float *b, float *c, int a0, int a1, int a2, int a3,
void add_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3,
int b0, int b1, int b2, int b3, int c0, int c1, int c2,
int c3) {

int blocksize = block_work_size();
int num = c0 * c1 * c2 * c3;
int gridsize = (num + block_work_size() - 1) / block_work_size();
_add_kernel<int64_t><<<gridsize, blocksize>>>(a, b, c, a0, a1, a2, a3, b0, b1, b2,
b3, c0, c1, c2, c3);
}
void pow_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3,
int b0, int b1, int b2, int b3, int c0, int c1, int c2,
int c3) {
int blocksize = block_work_size();
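
_add_kernel is templated on the element type but launched through a void * wrapper, and in this commit the wrapper instantiates it only for int64_t (matching the Add_CUDA_Int64 registration). A self-contained CUDA sketch of the same technique with the dtype chosen from a tag instead of hard-coded; broadcasting is dropped and the indexing reduced to 1-D to keep it short:

    #include <cstdint>
    #include <cuda_runtime.h>

    enum class DType { F32, I64 };

    template <class T>
    __global__ void addKernel(const void *x, const void *y, void *z, int n) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n)
            ((T *)z)[i] = ((const T *)x)[i] + ((const T *)y)[i];
    }

    // Host launcher: the public interface stays type-erased (void *),
    // while the template instantiation is picked from the dtype tag.
    void addLaunch(const void *x, const void *y, void *z, int n, DType dt) {
        int block = 128;
        int grid = (n + block - 1) / block;
        if (dt == DType::F32)
            addKernel<float><<<grid, block>>>(x, y, z, n);
        else if (dt == DType::I64)
            addKernel<int64_t><<<grid, block>>>(x, y, z, n);
    }
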
@@ -16,8 +16,9 @@ class PadSliceCudaCompute {
metadata.partNDim[i] = partTensor->getDims()[i];
metadata.partStride[i] = partTensor->getStride()[i];
}
pad_slice_kernel(partTensor->getRawDataPtr<float *>(),
wholeTensor->getRawDataPtr<float *>(), metadata, nDims,
metadata.DType = partTensor->getDType().getIndex();
pad_slice_kernel(partTensor->getRawDataPtr<void *>(),
wholeTensor->getRawDataPtr<void *>(), metadata, nDims,
wholeTensor->size(), isPad);
}
};
@@ -40,6 +41,8 @@ class SliceCuda : private PadSliceCudaCompute, public CudaKernelWithoutConfig {

REGISTER_KERNEL(Device::CUDA, OpType::Slice, DataType::Float32, SliceCuda,
"Slice__CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Slice, DataType::Int64, SliceCuda,
"Slice__CUDA_Int64");
REGISTER_KERNEL(Device::CUDA, OpType::Pad, DataType::Float32, PadCuda,
"Pad__CUDA_Float32");
} // namespace infini
@@ -1,5 +1,6 @@
#include "cuda/cuda_common.h"
#include "cuda/cuda_pad_slice.h"
#include "core/data_type.h"

__device__ int WholeTensorOffset2PartTensorOffset(int wholeOffset,
TransMetaData metaData,
@@ -19,7 +20,8 @@ __device__ int WholeTensorOffset2PartTensorOffset(int wholeOffset,
return offset;
}

__global__ void _pad_slice_kernel(float *part, float *whole,
template <typename T>
__global__ void _pad_slice_kernel(T *part, T *whole,
TransMetaData metaData, int nDims, int num,
bool isPad) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
@@ -41,12 +43,17 @@ __global__ void _pad_slice_kernel(float *part, float *whole,
}

namespace infini {
void pad_slice_kernel(float *partData, float *wholeData,
void pad_slice_kernel(void *partData, void *wholeData,
const TransMetaData &metadata, int nDims, int num,
bool isPad) {
int blockSize = 32 * 16;
int gridSize = (num + blockSize - 1) / blockSize;
_pad_slice_kernel<<<gridSize, blockSize>>>(partData, wholeData, metadata,
if (metadata.DType == DataType::Int64.getIndex()) {
_pad_slice_kernel<int64_t><<<gridSize, blockSize>>>((int64_t *)partData, (int64_t *)wholeData, metadata,
nDims, num, isPad);
} else if (metadata.DType == DataType::Float32.getIndex()) {
_pad_slice_kernel<float><<<gridSize, blockSize>>>((float*)partData, (float*)wholeData, metadata,
nDims, num, isPad);
}
}
} // namespace infini
@@ -13,6 +13,8 @@ class CopyCuda : public CudaKernelWithoutConfig {
// reshape/flatten/identity all act as copying from input to output.
REGISTER_KERNEL(Device::CUDA, OpType::Reshape, DataType::Float32, CopyCuda,
"Reshape_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Reshape, DataType::Int32, CopyCuda,
"Reshape_CUDA_Int32");
REGISTER_KERNEL(Device::CUDA, OpType::Flatten, DataType::Float32, CopyCuda,
"Flatten_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Identity, DataType::Float32, CopyCuda,
@@ -46,10 +46,10 @@ SliceObj::SliceObj(GraphObj *graph, Tensor input, Tensor output,
for (size_t i = 0; i < size; ++i)
if (auto _i = axes.find(i); _i != axes.end()) {
auto __i = _i->second;
auto start = starts[__i];
auto end = ends[__i];
this->axes.push_back({start >= 0 ? start : start + shape[__i],
end >= 0 ? end : end + shape[__i],
auto start = starts[__i] >= 0 ? starts[__i] : starts[__i] + shape[i];
auto end = ends[__i] >= 0 ? ends[__i] : ends[__i] + shape[i];
this->axes.push_back({start,
end,
steps[__i]});
} else {
this->axes.push_back({0, shape[i], 1});
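
The constructor now folds the ONNX convention for negative indices (an index of -k counts from the end of the axis) into start and end before storing them, and normalizes against shape[i], the dimension actually being sliced, rather than shape[__i]. A tiny worked example of that normalization:

    #include <cassert>

    // Normalize one ONNX-style slice bound against the axis length:
    // negative values count from the end.
    int normalizeBound(int v, int dimSize) {
        return v >= 0 ? v : v + dimSize;
    }

    int main() {
        assert(normalizeBound(2, 10) == 2);    // plain index stays as-is
        assert(normalizeBound(-3, 10) == 7);   // -3 on an axis of length 10 -> 7
    }
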
@@ -6,8 +6,8 @@ TransposeObj::TransposeObj(GraphObj *graph, Tensor input, Tensor output,
: OperatorObj(OpType::Transpose, {input}, {output}) {
auto rank = input->getRank();
if (permute.empty()) {
for (size_t i = 0; i < rank; ++i) {
transposePermute[i] = i;
for (size_t i = rank - 1; i >= 0; --i) {
transposePermute.emplace_back(i);
}
} else {
IT_ASSERT(rank == permute.size());
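
With no permute given, the constructor now fills transposePermute with the reversed axis order instead of the identity. One caveat: because i is declared size_t, the condition i >= 0 is always true and the index wraps around after reaching zero; the signed-index form used for the same default permutation in operator_utils.cc below terminates correctly. A standalone sketch of that form:

    #include <vector>

    // Default transpose permutation: reverse the axes, e.g. rank 4 -> {3, 2, 1, 0}.
    std::vector<int> defaultPermute(int rank) {
        std::vector<int> perm;
        for (int i = rank - 1; i >= 0; --i)   // signed index, stops after i == 0
            perm.emplace_back(i);
        return perm;
    }
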
@@ -1,3 +1,18 @@
#include "operators/batch_norm.h"
#include "operators/concat.h"
#include "operators/conv.h"
#include "operators/element_wise.h"
#include "operators/gather.h"
#include "operators/matmul.h"
#include "operators/pad.h"
#include "operators/pooling.h"
#include "operators/reduce_mean.h"
#include "operators/reshape.h"
#include "operators/slice.h"
#include "operators/softmax.h"
#include "operators/split.h"
#include "operators/transpose.h"
#include "operators/unary.h"
#include "utils/operator_utils.h"

namespace infini {
@@ -41,4 +56,168 @@ int get_real_axis(const int &axis, const int &rank) {
}
return newAxis;
}

void addOperatorFromGraphTopo(GraphObj &g,
std::shared_ptr<refactor::computation::Operator> nodeInfo,
std::vector<size_t> input, std::vector<size_t> output,
std::unordered_map<size_t, Tensor> &edgeToTensor,
std::vector<refactor::computation::Edge> edges) {
std::string name(nodeInfo->opType.name());
auto attr = nodeInfo->attributes;
#define ELSE_IF(op) \
else if (name == "onnx::op") { \
g.addOpWithOutputs<op##Obj>(edgeToTensor[input[0]], edgeToTensor[output[0]]); \
}
if (name == "onnx::Conv") {
// auto p = attr["pads"].ints();
// auto s = attr["strides"].ints();
// auto d = attr["dilations"].ints();
// g.addOpWithOutputs<ConvObj>(edgeToTensor[input[0]], edgeToTensor[input[1]], edgeToTensor[output[0]], p[0], p[1], s[0], s[1], d[0], d[1]);
} else if (name == "onnx::Add") {
g.addOpWithOutputs<AddObj>(edgeToTensor[input[0]], edgeToTensor[input[1]], edgeToTensor[output[0]]);
} else if (name == "onnx::AveragePool") {
// auto p = attr["pads"].ints();
// auto s = attr["strides"].ints();
// auto d = attr["dilations"].ints();
// int h = edgeToTensor[input[0]]->getDims()[2];
// int w = edgeToTensor[input[0]]->getDims()[3];
// g.addOpWithOutputs<AvgPoolObj>(edgeToTensor[input[0]], edgeToTensor[output[0]], h, w,
// d[0], d[1], p[0], p[1], s[0], s[1]);
} else if (name == "onnx::Reshape") {
IT_ASSERT(input.size() == 2);
auto shapeValue = reinterpret_cast<int64_t *>(edges[input[1]].tensor->data->ptr);
auto rank = edgeToTensor[input[1]]->getDims()[0];
Shape shape(rank);
for (size_t i = 0; i < (size_t)rank; ++i) {
shape[i] = static_cast<int>(*(shapeValue + i));
}
g.addOpWithOutputs<ReshapeObj>(edgeToTensor[input[0]], edgeToTensor[output[0]], shape);
} else if (name == "onnx::Gemm") {
auto alpha = attr.find("alpha") != attr.end() ? attr["alpha"].float_() : 1.0;
auto beta = attr.find("beta") != attr.end() ? attr["beta"].float_() : 1.0;
auto transA = attr.find("transA") != attr.end() ? attr["transA"].int_() : 0;
auto transB = attr.find("transB") != attr.end() ? attr["transB"].int_() : 0;
IT_ASSERT(alpha == 1.0);
IT_ASSERT(beta == 1.0);
g.addOpWithOutputs<MatmulObj>(edgeToTensor[input[0]], edgeToTensor[input[1]], edgeToTensor[output[0]], transA, transB,
input.size() > 2 ? edgeToTensor[input[2]] : nullptr, ActType::None);
} else if (name == "onnx::Pow") {
g.addOpWithOutputs<PowerObj>(edgeToTensor[input[0]], edgeToTensor[input[1]], edgeToTensor[output[0]]);
} else if (name == "onnx::Gather") {
auto axis = attr.find("axis") != attr.end() ? attr["axis"].int_() : 0;
g.addOpWithOutputs<GatherObj>(edgeToTensor[input[0]], edgeToTensor[input[1]], edgeToTensor[output[0]], axis);
} else if (name == "onnx::Max") {
g.addOpWithOutputs<MaximumObj>(edgeToTensor[input[0]], edgeToTensor[input[1]], edgeToTensor[output[0]]);
} else if (name == "onnx::Div") {
g.addOpWithOutputs<DivObj>(edgeToTensor[input[0]], edgeToTensor[input[1]], edgeToTensor[output[0]]);
} else if (name == "onnx::Mul") {
g.addOpWithOutputs<MulObj>(edgeToTensor[input[0]], edgeToTensor[input[1]], edgeToTensor[output[0]]);
} else if (name == "onnx::Sub") {
g.addOpWithOutputs<SubObj>(edgeToTensor[input[0]], edgeToTensor[input[1]], edgeToTensor[output[0]]);
} else if (name == "onnx::Slice") {
auto startValue = reinterpret_cast<int64_t *>(edges[input[1]].tensor->data->ptr);
auto startRank = edgeToTensor[input[1]]->getRank();
auto endValue = reinterpret_cast<int64_t *>(edges[input[2]].tensor->data->ptr);
auto endRank = edgeToTensor[input[2]]->getRank();
std::vector<int> start, end, axesVal, stepsVal;
std::optional<std::vector<int>> axes, steps;
if (input.size() > 3) {
auto axesValue = reinterpret_cast<int64_t *>(edges[input[3]].tensor->data->ptr);
auto axesRank = edgeToTensor[input[3]]->getRank();
for (size_t i = 0; i < axesRank; ++i) {
axesVal.emplace_back(static_cast<int>(*(axesValue + i)));
}
axes = axesVal;
}
if (input.size() > 4) {
auto stepsValue = reinterpret_cast<int64_t *>(edges[input[4]].tensor->data->ptr);
auto stepsRank = edgeToTensor[input[4]]->getRank();
for (size_t i = 0; i < stepsRank; ++i) {
stepsVal.emplace_back(static_cast<int>(*(stepsValue + i)));
}
steps = stepsVal;
}
for (size_t i = 0; i < startRank; ++i) {
int64_t startVal = *(startValue + i);
if (axes.has_value()) {
startVal = std::min(startVal, static_cast<int64_t>(edgeToTensor[input[0]]->getDims()[axes.value()[i]]));
} else {
startVal = std::min(startVal, static_cast<int64_t>(edgeToTensor[input[0]]->getDims()[i]));
}
start.emplace_back(static_cast<int>(startVal));
}
for (size_t i = 0; i < endRank; ++i) {
int64_t endVal = *(endValue + i);
if (axes.has_value()) {
endVal = std::min(endVal, static_cast<int64_t>(edgeToTensor[input[0]]->getDims()[axes.value()[i]]));
} else {
endVal = std::min(endVal, static_cast<int64_t>(edgeToTensor[input[0]]->getDims()[i]));
}
end.emplace_back(static_cast<int>(endVal));
}
g.addOpWithOutputs<SliceObj>(edgeToTensor[input[0]], edgeToTensor[output[0]], start, end,
axes, steps);
} else if (name == "onnx::Softmax") {
auto axis = attr.find("axis") != attr.end() ? attr["axis"].int_() : -1;
g.addOpWithOutputs<SoftmaxObj>(edgeToTensor[input[0]], edgeToTensor[output[0]], axis);
} else if (name == "onnx::ReduceMean") {
auto keepdims = attr.find("keepdims") != attr.end() ? attr["keepdims"].int_() : 1;
std::vector<int> axesVal;
std::optional<std::vector<int>> axes;
if (input.size() > 1) {
auto axesValue = reinterpret_cast<int64_t *>(edges[input[1]].tensor->data->ptr);
auto axesRank = edgeToTensor[input[1]]->getRank();
for (size_t i = 0; i < axesRank; ++i) {
axesVal.emplace_back(static_cast<int>(*(axesValue + i)));
}
axes = axesVal;
}
g.addOpWithOutputs<ReduceMeanObj>(edgeToTensor[input[0]], edgeToTensor[output[0]], axes, keepdims);
} else if (name == "onnx::Concat") {
auto axis = attr["axis"].int_();
std::vector<Tensor> inputs;
for (auto i : input) {
inputs.emplace_back(edgeToTensor[i]);
}
g.addOpWithOutputs<ConcatObj>(inputs, edgeToTensor[output[0]], axis);
} else if (name == "onnx::MatMul") {
g.addOpWithOutputs<MatmulObj>(edgeToTensor[input[0]], edgeToTensor[input[1]], edgeToTensor[output[0]], false, false, nullptr, ActType::None);
} else if (name == "onnx::Transpose") {
int rank = edgeToTensor[input[0]]->getRank();
std::vector<int> permDefault;
for (int i = rank - 1; i >= 0; --i) {
permDefault.emplace_back(i);
}
std::vector<int> perm;
if (attr.find("perm") != attr.end()) {
auto permAttr = attr["perm"].ints();
for (auto e : permAttr) {
perm.emplace_back(static_cast<int>(e));
}
} else {
perm = permDefault;
}
g.addOpWithOutputs<TransposeObj>(edgeToTensor[input[0]], edgeToTensor[output[0]], perm);
}
ELSE_IF(Relu)
ELSE_IF(Sqrt)
ELSE_IF(Identity)

#undef ELSE_IF
}

void addEdgeToTensor(GraphObj &g, size_t index,
std::shared_ptr<refactor::computation::Tensor> tensor,
std::unordered_map<size_t, Tensor> &edgeToTensor,
Runtime runtime) {
auto refShape = tensor->shape;
Shape shape;
for (auto ele : refShape) {
IT_ASSERT(ele.hasValue());
shape.emplace_back(ele.value());
}
auto dType = tensor->dataType;
Tensor tensorInf = g.addTensor(shape, DataType(static_cast<int>(dType)));
edgeToTensor.insert(std::make_pair(index, tensorInf));
}
} // namespace infini
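
One note on the ELSE_IF helper: the macro parameter reaches the class name through token pasting (op##Obj), but macro arguments are not substituted inside string literals, so every ELSE_IF branch compares name against the same literal "onnx::op" and the Relu/Sqrt/Identity cases can never match. Building the string with the stringizing operator avoids that. A minimal sketch of that variant (hypothetical, not the code in this commit):

    #include <iostream>
    #include <string>

    // "onnx::" #op concatenates to "onnx::Relu" etc.; op##Obj would still
    // paste the class name exactly as in the original macro.
    #define ELSE_IF(op)                                                        \
        else if (name == "onnx::" #op) {                                       \
            std::cout << "dispatch to " #op "Obj\n";                           \
        }

    void dispatch(const std::string &name) {
        if (name.empty()) {
            // nothing to dispatch
        }
        ELSE_IF(Relu)
        ELSE_IF(Sqrt)
        ELSE_IF(Identity)
    }

    #undef ELSE_IF

    int main() { dispatch("onnx::Relu"); }
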