Add dump interface

panzezhong 2023-09-14 16:35:31 +08:00
parent 4c321c8a91
commit 6ce2111842
7 changed files with 96 additions and 2 deletions

include/core/dump.h (new file, 44 lines added)
View File

@@ -0,0 +1,44 @@
#pragma once
#include "core/common.h"
#include "core/operator.h"
#include "core/runtime.h"
#include "core/tensor.h"
/**
 * A Dump stores intermediate states of a model run and exposes them to
 * outside queries.
 */
namespace infini {
class Dump {
protected:
string opKey;     // type name of the queried operator
int location = 0; // 0-based occurrence of that type in execution order
vector<Tensor> inputs;
vector<Tensor> outputs;
public:
Dump() {}
/*
 * Dump the inputs and outputs of an operator.
 */
void dumpOp(Operator op);
vector<Tensor> getInputs() { return inputs; }
vector<Tensor> getOutputs() { return outputs; }
// TODO: For now, use type name and count to locate a specific operator.
// In the future, use unique name or id of the queried operator.
void setOpQuery(string opKey, int location) {
this->opKey = opKey;
this->location = location;
}
string getOpKey() { return this->opKey; }
/*
 * Returns true if op matches the current query, where count is the number
 * of operators of the same type seen so far.
 */
bool queriedOp(Operator op, int count = 0);
};
} // namespace infini
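The query identifies an operator by its type name plus how many operators of
that type have already been executed. A minimal sketch of the intended
matching semantics (not part of this commit; the "Relu" type string is an
assumption for illustration):

#include "core/dump.h"

void queryThirdRelu(infini::Dump &dump) {
    // opKey is the operator type name, location the 0-based occurrence.
    dump.setOpQuery("Relu", 2);
    // A runtime walking the graph is then expected to call, per operator:
    //   if (dump.queriedOp(op, reluSeenSoFar)) dump.dumpOp(op);
    // so location == 2 selects the third Relu in execution order.
}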

View File

@@ -1,4 +1,5 @@
#pragma once
#include "core/dump.h"
#include "core/lazy_allocator.h"
#include "core/operator.h"
#include "core/tensor.h"
@@ -11,10 +12,11 @@ class GraphObj : public Object {
TensorVec tensors;
OpVec ops;
LazyAllocator allocator;
Dump dump;
public:
explicit GraphObj(Runtime runtime)
: runtime(runtime), allocator(runtime), sorted(false){};
: runtime(runtime), allocator(runtime), dump(), sorted(false){};
GraphObj(Runtime runtime, OpVec ops_in);
string toString() const override;
Runtime getRuntime() const { return runtime; }
@@ -40,6 +42,9 @@ class GraphObj : public Object {
tensors.erase(it);
}
void dumpOp(Operator op) { dump.dumpOp(op); }
Dump &getDump() { return dump; }
void deleteConnection(Tensor tensor, Operator op);
void addConnection(Tensor tensor, Operator op);
void replaceConnection(Tensor oldInput, Tensor newInput, Operator op);

View File

@@ -16,6 +16,8 @@ class GraphHandlerObj {
Tensor tensor(Shape dims, int dtype);
Dump &getDump() { return g->getDump(); }
//------ operators
inline OpVec operators() { return g->getOperators(); }

View File

@@ -44,6 +44,7 @@ class TensorBaseObj : public Object {
}
DataType getDType() const { return dtype; }
int getDTypeCode() const { return dtype.getIndex(); }
Runtime getRuntime() const { return runtime; }
// std::pair<Operator *, int> getOutputOfWithIndex();

src/core/dump.cc (new file, 22 lines added)
View File

@@ -0,0 +1,22 @@
#include "core/dump.h"
namespace infini {
void Dump::dumpOp(Operator op) {
inputs.clear();
outputs.clear();
// Clone the inputs and outputs to host and store in dump
for (Tensor input : op->getInputs()) {
inputs.push_back(input->clone(NativeCpuRuntimeObj::getInstance()));
}
for (Tensor output : op->getOutputs()) {
outputs.push_back(output->clone(NativeCpuRuntimeObj::getInstance()));
}
}
bool Dump::queriedOp(Operator op, int count) {
return strcmp(op->getOpType().toString(), opKey.c_str()) == 0 &&
location == count;
}
} // namespace infini
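Since dumpOp clones every input and output onto the native CPU runtime, the
stored tensors can be inspected after the run without touching the device. A
minimal sketch (not part of this commit; the copyout<float>() helper on
TensorObj is an assumption here):

#include "core/dump.h"
#include <iostream>

void printDumpedOutputs(infini::Dump &dump) {
    for (auto &t : dump.getOutputs()) {
        // Each tensor is already a host-side clone of the queried op's output.
        auto data = t->copyout<float>(); // assumed TensorObj helper
        std::cout << t->toString() << ": " << data.size() << " elements\n";
    }
}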

View File

@@ -23,6 +23,9 @@ namespace infini {
void CudaRuntimeObj::runWithoutSync(const Graph &graph) const {
const auto &kernelRegistry = KernelRegistry::getInstance();
auto &perfEngine = PerfEngine::getInstance();
#ifdef DEBUG_MODE
std::map<OpType, int> opCnt;
#endif
for (auto &op : graph->getOperators()) {
// HACK: set correct data type
auto kernelAttrs =
@@ -37,6 +40,13 @@ void CudaRuntimeObj::runWithoutSync(const Graph &graph) const {
kernel->compute(op, this);
}
checkCudaError(cudaGetLastError()) << op->toString();
#ifdef DEBUG_MODE
if (graph->getDump().queriedOp(op, opCnt[op->getOpType()])) {
graph->dumpOp(op);
}
opCnt[op->getOpType()]++;
#endif
}
}
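With DEBUG_MODE defined at build time, the dump happens as a side effect of a
normal run, on the queried operator only. A minimal end-to-end sketch (not
part of this commit; the graph construction, the "Matmul" type string, and the
header paths are assumptions):

#include "core/graph.h"
#include "cuda/cuda_runtime.h"

void dumpSecondMatmul(infini::Graph g, infini::Ref<infini::CudaRuntimeObj> cuda) {
    // Select the second Matmul in execution order (0-based location = 1).
    g->getDump().setOpQuery("Matmul", 1);
    cuda->run(g); // runWithoutSync dumps the op when the query matches
    auto inputs  = g->getDump().getInputs();  // host-side clones
    auto outputs = g->getDump().getOutputs(); // host-side clones
}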

View File

@@ -1,4 +1,5 @@
#include "core/data_type.h"
#include "core/dump.h"
#include "core/graph_handler.h"
#include "operators/batch_norm.h"
#include "operators/concat.h"
@@ -326,6 +327,7 @@ void init_graph_builder(py::module &m) {
py::buffer_protocol())
.def("fuid", &TensorObj::getFuid, policy::automatic)
.def("shape", &TensorObj::getDims, policy::move)
.def("dtype", &TensorObj::getDTypeCode, policy::automatic)
.def("copyin_float", &TensorObj::copyin<float>, policy::move)
.def("copyin_int32", &TensorObj::copyin<int32_t>, policy::move)
.def("copyin_int64", &TensorObj::copyin<int64_t>, policy::move)
@@ -381,6 +383,8 @@ void init_graph_builder(py::module &m) {
format = py::format_descriptor<int8_t>::format();
} else if (self.getDType() == DataType::UInt8) {
format = py::format_descriptor<uint8_t>::format();
} else if (self.getDType() == DataType::Bool) {
format = py::format_descriptor<bool>::format();
} else if (self.getDType() == DataType::Float16 ||
self.getDType() == DataType::BFloat16) {
// Python uses "e" for half precision float type code.
@@ -460,7 +464,13 @@ void init_graph_builder(py::module &m) {
.def("get_perf_time", &Handler::get_perf_time, policy::automatic)
.def("tune", &Handler::tune, policy::automatic)
.def("run", &Handler::run, policy::automatic)
.def("get_perf_time", &Handler::get_perf_time, policy::automatic);
.def("get_perf_time", &Handler::get_perf_time, policy::automatic)
.def("getDump", &Handler::getDump, policy::reference);
py::class_<Dump, std::shared_ptr<Dump>>(m, "Dump")
.def("setOpQuery", &Dump::setOpQuery, policy::automatic)
.def("getOutputs", &Dump::getOutputs, policy::move)
.def("getOpKey", &Dump::getOpKey, policy::automatic)
.def("getInputs", &Dump::getInputs, policy::move);
}
} // namespace infini