2023-02-12 08:23:49 +08:00
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
|
|
#include "core/graph.h"
#include "core/runtime.h"
#include <cstdint>
#include <iostream>
#include <optional>
|
2023-02-12 08:23:49 +08:00
|
|
|
|
|
|
|
|
|
namespace infini {
|
|
|
|
|
|
|
|
|
|
class GraphHandlerObj {
|
|
|
|
|
Graph g;
|
|
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
GraphHandlerObj(Runtime runtime)
|
|
|
|
|
: g(make_ref<GraphObj>(std::move(runtime))) {}
|
|
|
|
|
|
|
|
|
|
Tensor tensor(Shape dims, int dtype);
|
|
|
|
|
|
2023-02-17 15:00:34 +08:00
|
|
|
|
//------ operators
|
|
|
|
|
|
|
|
|
|
inline OpVec operators() { return g->getOperators(); }
|
|
|
|
|
|
2023-02-22 15:05:44 +08:00
|
|
|
|
Tensor conv(Tensor input, Tensor weight, Tensor output, int ph, int pw,
|
|
|
|
|
int sh, int sw, int dh, int dw);
|
2023-04-18 15:10:33 +08:00
|
|
|
|
Tensor convTransposed2d(Tensor input, Tensor weight, Tensor output, int ph,
|
|
|
|
|
int pw, int sh, int sw, int dh, int dw, int oph,
|
|
|
|
|
int opw);
|
2023-02-12 08:23:49 +08:00
|
|
|
|
Tensor matmul(Tensor a, Tensor b, Tensor y, bool transA, bool transB,
|
|
|
|
|
Tensor bias, ActType act);
|
2023-08-07 11:17:05 +08:00
|
|
|
|
Tensor batchNormalization(Tensor input, Tensor output, Tensor mean,
|
|
|
|
|
Tensor var, Tensor scale, Tensor bias,
|
|
|
|
|
float momentum, float eps, bool training);
|
2023-02-13 17:15:35 +08:00
|
|
|
|
|
2023-02-14 16:26:47 +08:00
|
|
|
|
Tensor maxPool(Tensor input, Tensor output, int kh, int kw, int dh, int dw,
|
|
|
|
|
int ph, int pw, int sh, int sw);
|
|
|
|
|
Tensor avgPool(Tensor input, Tensor output, int kh, int kw, int dh, int dw,
|
|
|
|
|
int ph, int pw, int sh, int sw);
|
|
|
|
|
|
2023-02-13 11:25:54 +08:00
|
|
|
|
Tensor add(Tensor a, Tensor b, Tensor c);
|
|
|
|
|
Tensor sub(Tensor a, Tensor b, Tensor c);
|
|
|
|
|
Tensor mul(Tensor a, Tensor b, Tensor c);
|
|
|
|
|
Tensor div(Tensor a, Tensor b, Tensor c);
|
|
|
|
|
Tensor pow(Tensor a, Tensor b, Tensor c);
|
2023-02-13 11:54:54 +08:00
|
|
|
|
|
|
|
|
|
Tensor relu(Tensor x, Tensor y);
|
|
|
|
|
Tensor sigmoid(Tensor x, Tensor y);
|
|
|
|
|
Tensor tanh(Tensor x, Tensor y);
|
2023-04-17 12:15:23 +08:00
|
|
|
|
Tensor softmax(Tensor x, Tensor y, int axis);
|
2023-02-13 11:54:54 +08:00
|
|
|
|
Tensor abs(Tensor x, Tensor y);
|
2023-08-18 12:17:47 +08:00
|
|
|
|
Tensor sqrt(Tensor x, Tensor y);
|
2023-04-18 15:10:33 +08:00
|
|
|
|
Tensor shape(Tensor x, Tensor y);
|
2023-02-13 12:13:01 +08:00
|
|
|
|
Tensor identity(Tensor x, Tensor y);
|
2023-04-17 12:15:23 +08:00
|
|
|
|
Tensor flatten(Tensor s, Tensor y, int axis);
|
2023-04-18 15:10:33 +08:00
|
|
|
|
Tensor pRelu(Tensor x, Tensor slope, Tensor y);
|
|
|
|
|
Tensor clip(Tensor x, Tensor y, std::optional<float> min,
|
|
|
|
|
std::optional<float> max);
|
|
|
|
|
Tensor transpose(Tensor data, Tensor transposed, Shape perm);
|
2023-02-14 09:50:32 +08:00
|
|
|
|
Tensor reshape(Tensor data, Tensor reshaped, Shape shape);
|
2023-02-14 13:42:35 +08:00
|
|
|
|
Tensor concat(TensorVec inputs, Tensor output, int dim);
|
2023-04-18 15:10:33 +08:00
|
|
|
|
TensorVec split(Tensor input, std::optional<TensorVec> outputs, int axis,
|
|
|
|
|
int num_outputs);
|
2023-02-14 15:35:01 +08:00
|
|
|
|
Tensor gather(Tensor data, Tensor indices, Tensor output, int axis);
|
|
|
|
|
Tensor reduceMean(Tensor data, Tensor reduced,
|
|
|
|
|
const optional<vector<int>> &axes, bool keepdims);
|
2023-02-14 17:35:18 +08:00
|
|
|
|
Tensor slice(Tensor input, Tensor output, const vector<int> &starts,
|
|
|
|
|
const vector<int> &ends, const optional<vector<int>> &axes,
|
|
|
|
|
const optional<vector<int>> &steps);
|
2023-02-15 11:41:06 +08:00
|
|
|
|
Tensor pad(Tensor input, Tensor output, const vector<int> &pads,
|
|
|
|
|
const optional<vector<int>> &axes);
|
2023-08-16 21:49:43 +08:00
|
|
|
|
Tensor cast(Tensor input, Tensor output, int to);
|
2023-02-23 11:08:00 +08:00
|
|
|
|
|
2023-02-17 15:00:34 +08:00
|
|
|
|
//------ modifiers
|
|
|
|
|
|
|
|
|
|
inline bool topo_sort() { return g->topo_sort(); }
|
|
|
|
|
|
2023-04-18 15:10:33 +08:00
|
|
|
|
inline void optimize() { g->optimize(); }
|
|
|
|
|
|
2023-02-23 11:08:00 +08:00
|
|
|
|
//------ runtime
|
|
|
|
|
|
|
|
|
|
inline void data_malloc() { g->dataMalloc(); }
|
|
|
|
|
|
2023-06-29 07:15:07 +08:00
|
|
|
|
inline void tune() { g->getRuntime()->run(g, true); }
|
|
|
|
|
|
2023-02-23 11:08:00 +08:00
|
|
|
|
inline void run() { g->getRuntime()->run(g); }
|
2023-06-29 07:15:07 +08:00
|
|
|
|
|
|
|
|
|
inline double get_perf_time() { return g->getRuntime()->getPerfTime(g); }
|
2023-02-12 08:23:49 +08:00
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace infini
|