forked from jiuyuan/InfiniTensor
Add: perf engine
This commit is contained in:
parent
6c356d5b42
commit
efa966a3e2
|
@ -12,7 +12,7 @@
|
|||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
namespace it {
|
||||
namespace infini {
|
||||
using std::list;
|
||||
using std::map;
|
||||
using std::nullopt;
|
||||
|
@ -46,8 +46,8 @@ using dtype = float;
|
|||
#define _IT_ASSERT_1(name) _IT_ASSERT_2(name, "");
|
||||
|
||||
#define IT_ASSERT(...) _VA_SELECT(_IT_ASSERT, __VA_ARGS__)
|
||||
#define IT_TODO_HALT(...) IT_ASSERT(false, "Unimplemented")
|
||||
#define IT_TODO_SKIP(...) puts("Unimplemented " __FILE__ ":" __LINE__)
|
||||
#define IT_TODO_HALT() IT_ASSERT(false, "Unimplemented")
|
||||
#define IT_TODO_SKIP() puts("Unimplemented " __FILE__ ":" __LINE__)
|
||||
|
||||
// Other utilities
|
||||
|
||||
|
@ -56,4 +56,6 @@ template <typename T> auto enum_to_underlying(T e) {
|
|||
return static_cast<std::underlying_type_t<T>>(e);
|
||||
}
|
||||
|
||||
} // namespace it
|
||||
double timeit(const std::function<void()> &func);
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
#include "core/operator.h"
|
||||
#include "core/tensor.h"
|
||||
|
||||
namespace it {
|
||||
namespace infini {
|
||||
|
||||
class GraphNode : public Object {
|
||||
protected:
|
||||
|
@ -32,12 +32,9 @@ class GraphNode : public Object {
|
|||
void updateConnection();
|
||||
void dataMalloc();
|
||||
|
||||
// TODO
|
||||
// bool compute();
|
||||
|
||||
// TODO: move to another class
|
||||
// bool exportOnnx(const char *path);
|
||||
// bool importOnnx(const char *net);
|
||||
};
|
||||
|
||||
} // namespace it
|
||||
} // namespace infini
|
|
@ -3,22 +3,26 @@
|
|||
#include "core/operator.h"
|
||||
#include "core/tensor.h"
|
||||
|
||||
namespace it {
|
||||
namespace infini {
|
||||
|
||||
enum class Device { CPU = 1, CUDA };
|
||||
struct PerfRecord {
|
||||
double time; // in milliseconds
|
||||
};
|
||||
|
||||
class Kernel {
|
||||
public:
|
||||
Kernel() {}
|
||||
virtual ~Kernel() {}
|
||||
|
||||
virtual void compute(const Operator &op,
|
||||
const PerfRecord &record) const = 0;
|
||||
// This function call compute with a default record.
|
||||
virtual void compute(const Operator &op) const = 0;
|
||||
// Tuning should be idempotent since it is called multiple times.
|
||||
virtual PerfRecord tune(const Operator &op) const = 0;
|
||||
};
|
||||
|
||||
class KernelRegistry {
|
||||
public:
|
||||
using Key = std::tuple<Device, OpType, DataType>;
|
||||
|
||||
public:
|
||||
~KernelRegistry() {
|
||||
for (auto &[k, v] : kernels)
|
||||
|
@ -28,29 +32,29 @@ class KernelRegistry {
|
|||
static KernelRegistry instance;
|
||||
return instance;
|
||||
}
|
||||
bool registerKernel(const Key &key, Kernel *kernel) {
|
||||
bool registerKernel(const KernelAttrs &key, Kernel *kernel) {
|
||||
// TODO: kernels with priority
|
||||
IT_ASSERT(kernels.find(key) == kernels.end(),
|
||||
"Kernel already registered");
|
||||
kernels.emplace(key, kernel);
|
||||
return true;
|
||||
}
|
||||
Kernel *getKernel(Device device, OpType opType, DataType dataType) const {
|
||||
return kernels.at(Key{device, opType, dataType});
|
||||
Kernel *getKernel(const KernelAttrs &kernelAttrs) const {
|
||||
return kernels.at(kernelAttrs);
|
||||
}
|
||||
|
||||
private:
|
||||
std::map<Key, Kernel *> kernels;
|
||||
std::map<KernelAttrs, Kernel *> kernels;
|
||||
};
|
||||
|
||||
#define _REGISTER_KERNEL_1(device, opType, dataType, kernel, cnt) \
|
||||
namespace it { \
|
||||
namespace infini { \
|
||||
static const bool _CAT(_register_kernel_, cnt) = \
|
||||
KernelRegistry::getInstance().registerKernel( \
|
||||
KernelRegistry::Key{device, opType, dataType}, new kernel()); \
|
||||
KernelAttrs{device, opType, dataType}, new kernel()); \
|
||||
}
|
||||
|
||||
#define REGISTER_KERNEL(device, opType, dataType, kernel) \
|
||||
_REGISTER_KERNEL_1(device, opType, dataType, kernel, __COUNTER__)
|
||||
|
||||
} // namespace it
|
||||
} // namespace infini
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
#include "core/common.h"
|
||||
#include "ref.h"
|
||||
|
||||
namespace it {
|
||||
namespace infini {
|
||||
|
||||
using GuidBaseType = int;
|
||||
|
||||
|
@ -51,4 +51,4 @@ inline std::ostream &operator<<(std::ostream &os, const Ref<T> &obj) {
|
|||
return os;
|
||||
}
|
||||
|
||||
} // namespace it
|
||||
} // namespace infini
|
|
@ -1,7 +1,7 @@
|
|||
#pragma once
|
||||
#include "core/tensor.h"
|
||||
|
||||
namespace it {
|
||||
namespace infini {
|
||||
|
||||
enum class OpType {
|
||||
Unknown = 0,
|
||||
|
@ -37,9 +37,13 @@ enum class OpType {
|
|||
MemBound = 300,
|
||||
};
|
||||
|
||||
enum class Device { CPU = 1, CUDA };
|
||||
|
||||
using KernelAttrs = std::tuple<Device, OpType, DataType>;
|
||||
|
||||
class OpRegistry {
|
||||
public:
|
||||
std::string getOpName(OpType opType) {
|
||||
static std::string getOpName(OpType opType) {
|
||||
#define FOP(op) \
|
||||
case OpType::op: \
|
||||
return #op
|
||||
|
@ -90,6 +94,16 @@ enum class ActType {
|
|||
Tanh,
|
||||
};
|
||||
|
||||
struct OpAttrs {
|
||||
public:
|
||||
virtual bool operator<(const OpAttrs &rhs) const {
|
||||
IT_ASSERT(typeid(*this) == typeid(rhs), "OpAttrs type mismatch.");
|
||||
// Empty OpAttrs are equal
|
||||
return false;
|
||||
}
|
||||
virtual ~OpAttrs() {}
|
||||
};
|
||||
|
||||
class OperatorNode : public Object {
|
||||
public:
|
||||
protected:
|
||||
|
@ -103,6 +117,7 @@ class OperatorNode : public Object {
|
|||
OperatorNode(OpType opType, TensorVec inputs, TensorVec outputs)
|
||||
: type(opType), inputs(inputs), outputs(outputs) {}
|
||||
virtual vector<Shape> computeShape() const = 0;
|
||||
virtual OpAttrs getOpAttrs() const = 0;
|
||||
|
||||
public: // check Op type
|
||||
bool isLinearOp() const;
|
||||
|
@ -132,13 +147,25 @@ class OperatorNode : public Object {
|
|||
|
||||
class MatmulNode : public OperatorNode {
|
||||
public:
|
||||
struct MatmulArgs {
|
||||
struct MatmulArgs : public OpAttrs {
|
||||
int b, m, n, k;
|
||||
// PET assume a row-major tensor layout. transA=false means default
|
||||
// dims, true means A should be transposed before matmul. This is in
|
||||
// oppsite to column-major BLAS.
|
||||
bool transA, transB;
|
||||
ActType act;
|
||||
|
||||
MatmulArgs(int b, int m, int n, int k, bool transA, bool transB,
|
||||
ActType act)
|
||||
: b(b), m(m), n(n), k(k), transA(transA), transB(transB), act(act) {
|
||||
}
|
||||
|
||||
bool operator<(const OpAttrs &rhsGeneric) {
|
||||
auto rhs = dynamic_cast<const MatmulArgs &>(rhsGeneric);
|
||||
return std::tie(b, m, n, k, transA, transB, act) <
|
||||
std::tie(rhs.b, rhs.m, rhs.n, rhs.k, rhs.transA, rhs.transB,
|
||||
rhs.act);
|
||||
}
|
||||
};
|
||||
|
||||
private:
|
||||
|
@ -162,6 +189,7 @@ class MatmulNode : public OperatorNode {
|
|||
bool getTransB() const { return args.transB; }
|
||||
|
||||
MatmulArgs getArgs() const { return args; }
|
||||
OpAttrs getOpAttrs() const override { return args; }
|
||||
|
||||
private:
|
||||
// Q: whether to check the output? Since we can build an Op first and then
|
||||
|
@ -170,4 +198,4 @@ class MatmulNode : public OperatorNode {
|
|||
bool checkValid(const TensorVec &inputs) const;
|
||||
};
|
||||
|
||||
} // namespace it
|
||||
} // namespace infini
|
|
@ -0,0 +1,34 @@
|
|||
#pragma once
|
||||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
class PerfEngine {
|
||||
public:
|
||||
using Key = std::pair<KernelAttrs, OpAttrs>;
|
||||
|
||||
private:
|
||||
map<Key, PerfRecord> data;
|
||||
|
||||
public:
|
||||
static PerfEngine &getInstance() {
|
||||
static PerfEngine instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
std::optional<PerfRecord> getPerfData(const Key &key) {
|
||||
auto it = data.find(key);
|
||||
if (it != data.end()) // find previous evaluating results
|
||||
return data.at(key);
|
||||
else
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
void setPerfData(const Key &key, const PerfRecord &record) {
|
||||
IT_ASSERT(data.find(key) == data.end(), "Perf data already exist");
|
||||
data.emplace(key, record);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace infini
|
|
@ -3,7 +3,7 @@
|
|||
#include <memory>
|
||||
#include <type_traits>
|
||||
|
||||
namespace it {
|
||||
namespace infini {
|
||||
|
||||
template <typename T> using Ref = std::shared_ptr<T>;
|
||||
template <typename T> using WRef = std::weak_ptr<T>;
|
||||
|
@ -32,4 +32,4 @@ std::vector<WRef<T>> get_wref_vec(const std::vector<Ref<T>> &vec) {
|
|||
return wref_vec;
|
||||
}
|
||||
|
||||
} // namespace it
|
||||
} // namespace infini
|
|
@ -1,24 +1,26 @@
|
|||
#pragma once
|
||||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/perf_engine.h"
|
||||
|
||||
namespace it {
|
||||
namespace infini {
|
||||
|
||||
class RunEngine {
|
||||
private:
|
||||
Device device;
|
||||
|
||||
public:
|
||||
RunEngine(Device device) : device(device) {}
|
||||
~RunEngine() {}
|
||||
|
||||
void run(Graph graph) const {
|
||||
const auto &kernelRegistry = KernelRegistry::getInstance();
|
||||
for (auto &op : graph->getOperators()) {
|
||||
// HACK: set correct data type
|
||||
Kernel *kernel = kernelRegistry.getKernel(device, op->getOpType(),
|
||||
DataType::Int32);
|
||||
kernel->compute(op);
|
||||
}
|
||||
}
|
||||
void run(const Graph &graph, bool tune = false,
|
||||
bool profiling = false) const;
|
||||
double getPerfTime(const Graph &graph, bool profiling = false) const;
|
||||
|
||||
private:
|
||||
Device device;
|
||||
void printProfilingData(double totTime,
|
||||
const std::map<OpType, double> &opTime,
|
||||
const std::map<OpType, int> &opCnt) const;
|
||||
};
|
||||
} // namespace it
|
||||
|
||||
} // namespace infini
|
|
@ -1,7 +1,7 @@
|
|||
#pragma once
|
||||
#include "core/tensor_base.h"
|
||||
|
||||
namespace it {
|
||||
namespace infini {
|
||||
|
||||
// TODO: how to deal with this
|
||||
using ShapeElem = int;
|
||||
|
@ -177,4 +177,4 @@ class TensorNode : public TensorBaseNode {
|
|||
// void printShape();
|
||||
};
|
||||
|
||||
} // namespace it
|
||||
} // namespace infini
|
|
@ -2,7 +2,7 @@
|
|||
#include "core/object.h"
|
||||
#include "core/ref.h"
|
||||
|
||||
namespace it {
|
||||
namespace infini {
|
||||
|
||||
// class Tensor;
|
||||
class TensorBaseNode;
|
||||
|
@ -267,4 +267,4 @@ class TensorBaseNode : public Object {
|
|||
// void printShape();
|
||||
};
|
||||
|
||||
} // namespace it
|
||||
} // namespace infini
|
|
@ -0,0 +1,14 @@
|
|||
#include "core/common.h"
|
||||
#include <chrono>
|
||||
#include <functional>
|
||||
|
||||
namespace infini {
|
||||
|
||||
double timeit(const std::function<void()> &func) {
|
||||
auto start = std::chrono::high_resolution_clock::now();
|
||||
func();
|
||||
auto end = std::chrono::high_resolution_clock::now();
|
||||
return std::chrono::duration<double, std::milli>(end - start).count();
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -1,6 +1,6 @@
|
|||
#include "core/graph.h"
|
||||
|
||||
namespace it {
|
||||
namespace infini {
|
||||
|
||||
void GraphNode::updateConnection() { IT_TODO_HALT(); }
|
||||
|
||||
|
@ -17,4 +17,4 @@ void GraphNode::dataMalloc() {
|
|||
tensor->dataMalloc();
|
||||
}
|
||||
|
||||
} // namespace it
|
||||
} // namespace infini
|
|
@ -1,6 +1,6 @@
|
|||
#include "core/operator.h"
|
||||
|
||||
namespace it {
|
||||
namespace infini {
|
||||
|
||||
bool OperatorNode::isLinearOp() const {
|
||||
return enum_to_underlying(type) >= 100 && enum_to_underlying(type) < 200;
|
||||
|
@ -37,13 +37,9 @@ vector<Shape> MatmulNode::computeShape() const {
|
|||
MatmulNode::MatmulNode(Tensor A, Tensor B, Tensor C, bool transA, bool transB,
|
||||
Tensor bias, ActType act)
|
||||
: OperatorNode(OpType::Matmul, {A, B, bias}, {C}),
|
||||
args{.b = A->getDims()[0],
|
||||
.m = transA ? A->getDims()[2] : A->getDims()[1],
|
||||
.n = transB ? B->getDims()[1] : B->getDims()[2],
|
||||
.k = transA ? A->getDims()[1] : A->getDims()[2],
|
||||
.transA = transA,
|
||||
.transB = transB,
|
||||
.act = act} {
|
||||
args(A->getDims()[0], transA ? A->getDims()[2] : A->getDims()[1],
|
||||
transB ? B->getDims()[1] : B->getDims()[2],
|
||||
transA ? A->getDims()[1] : A->getDims()[2], transA, transB, act) {
|
||||
IT_ASSERT(checkValid(inputs));
|
||||
}
|
||||
|
||||
|
@ -78,4 +74,4 @@ bool MatmulNode::checkValid(const TensorVec &inputs) const {
|
|||
return true;
|
||||
}
|
||||
|
||||
} // namespace it
|
||||
} // namespace infini
|
|
@ -0,0 +1,105 @@
|
|||
#include "core/run_enigne.h"
|
||||
#include <chrono>
|
||||
|
||||
namespace infini {
|
||||
|
||||
void RunEngine::run(const Graph &graph, bool tune, bool profiling) const {
|
||||
if (!tune && profiling)
|
||||
IT_TODO_HALT();
|
||||
const auto &kernelRegistry = KernelRegistry::getInstance();
|
||||
auto perfEngine = PerfEngine::getInstance();
|
||||
// Statistics
|
||||
double totalTime = 0;
|
||||
std::map<OpType, double> opTime;
|
||||
std::map<OpType, int> opCnt;
|
||||
std::chrono::system_clock::time_point begin, end;
|
||||
|
||||
for (auto &op : graph->getOperators()) {
|
||||
// HACK: set correct data type
|
||||
auto kernelAttrs =
|
||||
KernelAttrs{device, op->getOpType(), DataType::Int32};
|
||||
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
|
||||
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpAttrs()};
|
||||
std::optional<PerfRecord> perfData = perfEngine.getPerfData(perfKey);
|
||||
|
||||
// If no record and disable tuning, run with the default argument
|
||||
if (!perfData && !tune) {
|
||||
kernel->compute(op);
|
||||
continue;
|
||||
}
|
||||
|
||||
// TODO: The copy of record should be eliminated
|
||||
PerfRecord record;
|
||||
// Tune the kernel if there is no record
|
||||
if (!perfData) {
|
||||
record = kernel->tune(op);
|
||||
perfEngine.setPerfData(perfKey, record);
|
||||
} else
|
||||
record = *perfData;
|
||||
|
||||
if (!profiling) {
|
||||
kernel->compute(op, *perfData);
|
||||
continue;
|
||||
} else {
|
||||
double t = timeit([&]() { kernel->compute(op, *perfData); });
|
||||
op->print();
|
||||
printf(" op_time %lf\n", t);
|
||||
totalTime += t;
|
||||
opTime[op->getOpType()] += t;
|
||||
opCnt[op->getOpType()]++;
|
||||
}
|
||||
}
|
||||
if (profiling)
|
||||
printProfilingData(totalTime, opTime, opCnt);
|
||||
}
|
||||
|
||||
double RunEngine::getPerfTime(const Graph &graph, bool profiling) const {
|
||||
const auto &kernelRegistry = KernelRegistry::getInstance();
|
||||
auto perfEngine = PerfEngine::getInstance();
|
||||
// Statistics
|
||||
double totalTime = 0;
|
||||
std::map<OpType, double> opTime;
|
||||
std::map<OpType, int> opCnt;
|
||||
|
||||
for (auto &op : graph->getOperators()) {
|
||||
// HACK: set correct data type
|
||||
auto kernelAttrs =
|
||||
KernelAttrs{device, op->getOpType(), DataType::Int32};
|
||||
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
|
||||
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpAttrs()};
|
||||
std::optional<PerfRecord> perfData = perfEngine.getPerfData(perfKey);
|
||||
|
||||
PerfRecord record;
|
||||
// Tune the kernel if there is no record
|
||||
if (!perfData) {
|
||||
record = kernel->tune(op);
|
||||
perfEngine.setPerfData(perfKey, record);
|
||||
} else
|
||||
record = *perfData;
|
||||
|
||||
double t = record.time;
|
||||
totalTime += t;
|
||||
if (profiling) {
|
||||
op->print();
|
||||
printf(" op_time %lf\n", t);
|
||||
opTime[op->getOpType()] += t;
|
||||
opCnt[op->getOpType()]++;
|
||||
}
|
||||
}
|
||||
if (profiling)
|
||||
printProfilingData(totalTime, opTime, opCnt);
|
||||
return totalTime;
|
||||
}
|
||||
|
||||
void RunEngine::printProfilingData(double totalTime,
|
||||
const std::map<OpType, double> &opTime,
|
||||
const std::map<OpType, int> &opCnt) const {
|
||||
printf("%11s %3s %7s %7s %7s\n", "Op", "Cnt", "T_tot", "Percent", "T_mean");
|
||||
for (const auto &[type, t] : opTime) {
|
||||
printf("%11s %3d %7.3f %7.1f %7.3f\n",
|
||||
OpRegistry::getOpName(type).data(), opCnt.at(type), t,
|
||||
t / totalTime * 100, t / opCnt.at(type));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -1,5 +1,5 @@
|
|||
#include <core/tensor.h>
|
||||
namespace it {
|
||||
namespace infini {
|
||||
|
||||
TensorNode::TensorNode(const Shape &shape, DataType dtype)
|
||||
: TensorBaseNode(shape.size(), dtype), shape(shape) {}
|
||||
|
@ -87,4 +87,4 @@ bool TensorNode::equalData(const Tensor &rhs) const {
|
|||
return true;
|
||||
}
|
||||
|
||||
}; // namespace it
|
||||
}; // namespace infini
|
|
@ -1,9 +1,9 @@
|
|||
#include <core/tensor_base.h>
|
||||
namespace it {
|
||||
namespace infini {
|
||||
|
||||
TensorBaseNode::TensorBaseNode(int dim, DataType dtype)
|
||||
: dim(dim), dtype(dtype) {}
|
||||
|
||||
VType TensorBaseNode::getData(size_t offset) const { return data[offset]; }
|
||||
|
||||
}; // namespace it
|
||||
}; // namespace infini
|
|
@ -1,6 +1,6 @@
|
|||
#include "core/kernel.h"
|
||||
|
||||
namespace it {
|
||||
namespace infini {
|
||||
|
||||
template <typename T> class NaiveMatmul : public Kernel {
|
||||
void compute(const Operator &_op) const override {
|
||||
|
@ -14,12 +14,21 @@ template <typename T> class NaiveMatmul : public Kernel {
|
|||
const int M = args.m, N = args.n, K = args.k;
|
||||
for (int i = 0; i < M; i++) {
|
||||
for (int j = 0; j < N; j++) {
|
||||
C[i * N + j] = 0;
|
||||
for (int k = 0; k < K; k++) {
|
||||
C[i * N + j] += A[i * K + k] * B[k * N + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void compute(const Operator &op, const PerfRecord &record) const override {
|
||||
compute(op);
|
||||
}
|
||||
|
||||
PerfRecord tune(const Operator &op) const override {
|
||||
return PerfRecord{.time = timeit([this, &op]() { compute(op); })};
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Matmul, DataType::Int32,
|
||||
|
@ -27,4 +36,4 @@ REGISTER_KERNEL(Device::CPU, OpType::Matmul, DataType::Int32,
|
|||
REGISTER_KERNEL(Device::CPU, OpType::Matmul, DataType::Float32,
|
||||
NaiveMatmul<float>);
|
||||
|
||||
} // namespace it
|
||||
} // namespace infini
|
|
@ -2,9 +2,9 @@
|
|||
#include "core/run_enigne.h"
|
||||
#include "test.h"
|
||||
|
||||
namespace it {
|
||||
namespace infini {
|
||||
|
||||
TEST(Graph, build) {
|
||||
TEST(Graph, build_and_run) {
|
||||
Graph g = make_ref<GraphNode>();
|
||||
Tensor i0 = g->addTensor({1, 2, 3}, DataType::Int32);
|
||||
Tensor w0 = g->addTensor({1, 3, 4}, DataType::Int32);
|
||||
|
@ -21,4 +21,25 @@ TEST(Graph, build) {
|
|||
EXPECT_TRUE(o0->equalData(ans));
|
||||
}
|
||||
|
||||
} // namespace it
|
||||
TEST(Graph, perf_engine) {
|
||||
Graph g = make_ref<GraphNode>();
|
||||
Tensor i0 = g->addTensor({1, 2, 3}, DataType::Int32);
|
||||
Tensor w0 = g->addTensor({1, 3, 4}, DataType::Int32);
|
||||
Tensor o0 = g->addTensor({1, 2, 4}, DataType::Int32);
|
||||
g->dataMalloc();
|
||||
i0->copyData(vector<VType>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}.data());
|
||||
w0->copyData(vector<VType>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}.data());
|
||||
g->addOp(make_ref<MatmulNode>(i0, w0, o0));
|
||||
RunEngine(Device::CPU).run(g, true, true);
|
||||
double perfTime = RunEngine(Device::CPU).getPerfTime(g);
|
||||
// The example matmul takes 0.0036ms with one core
|
||||
EXPECT_GT(perfTime, 0);
|
||||
EXPECT_LT(perfTime, 0.01);
|
||||
// check answer
|
||||
auto ans = make_ref<TensorNode>(Shape{1, 2, 4}, DataType::Int32);
|
||||
ans->dataMalloc();
|
||||
ans->copyData(vector<VType>{38, 44, 50, 56, 83, 98, 113, 128}.data());
|
||||
EXPECT_TRUE(o0->equalData(ans));
|
||||
}
|
||||
|
||||
} // namespace infini
|
Loading…
Reference in New Issue