2022-08-22 15:01:03 +08:00
|
|
|
#pragma once
|
2022-08-23 16:55:59 +08:00
|
|
|
#include "core/common.h"
|
2023-09-05 09:47:35 +08:00
|
|
|
#include "core/communicator.h"
|
2023-08-07 11:17:05 +08:00
|
|
|
#include "core/op_type.h"
|
2022-08-23 16:55:59 +08:00
|
|
|
#include "core/ref.h"
|
|
|
|
#include <memory>
|
2022-09-22 15:34:34 +08:00
|
|
|
|
2022-08-22 15:01:03 +08:00
|
|
|
namespace infini {
|
|
|
|
|
2022-08-23 16:55:59 +08:00
|
|
|
/***************** Forward declaration begin *****************/
|
|
|
|
class TensorBaseObj;
|
|
|
|
class TensorObj;
|
|
|
|
class OperatorObj;
|
|
|
|
class GraphObj;
|
2023-02-12 08:23:49 +08:00
|
|
|
class GraphHandlerObj;
|
2022-08-23 16:55:59 +08:00
|
|
|
class RuntimeObj;
|
|
|
|
class BlobObj;
|
|
|
|
|
|
|
|
using TensorBase = Ref<TensorBaseObj>;
|
|
|
|
using Tensor = Ref<TensorObj>;
|
|
|
|
using Operator = Ref<OperatorObj>;
|
|
|
|
using Graph = Ref<GraphObj>;
|
2023-02-12 08:23:49 +08:00
|
|
|
using GraphHandler = Ref<GraphHandlerObj>;
|
2022-08-23 16:55:59 +08:00
|
|
|
using Runtime = Ref<RuntimeObj>;
|
|
|
|
using Blob = Ref<BlobObj>;
|
|
|
|
|
|
|
|
using TensorVec = vector<Tensor>;
|
|
|
|
using OpVec = vector<Operator>;
|
2023-04-17 13:09:07 +08:00
|
|
|
using OpLists = list<Operator>;
|
2022-08-23 16:55:59 +08:00
|
|
|
|
|
|
|
using VType = uint32_t;
|
|
|
|
|
2023-10-16 10:57:08 +08:00
|
|
|
enum class Device { CPU = 1, CUDA, BANG, INTELCPU, KUNLUN };
|
2022-08-23 16:55:59 +08:00
|
|
|
/***************** Forward declaration end *****************/
|
|
|
|
|
2022-08-22 15:01:03 +08:00
|
|
|
class RuntimeObj : public std::enable_shared_from_this<RuntimeObj> {
|
|
|
|
protected:
|
|
|
|
Device device;
|
2023-09-05 09:47:35 +08:00
|
|
|
int deviceId;
|
2022-08-22 15:01:03 +08:00
|
|
|
|
|
|
|
public:
|
2023-09-05 09:47:35 +08:00
|
|
|
explicit RuntimeObj(Device device, int deviceId = 0)
|
|
|
|
: device(device), deviceId(deviceId) {}
|
2022-08-22 15:01:03 +08:00
|
|
|
RuntimeObj(RuntimeObj &other) = delete;
|
|
|
|
RuntimeObj &operator=(RuntimeObj const &) = delete;
|
|
|
|
virtual ~RuntimeObj() {}
|
|
|
|
|
|
|
|
/**
|
|
|
|
* @brief Execute a graph.
|
|
|
|
*
|
|
|
|
* @param graph
|
|
|
|
* @param tune If there is no performance record, whether to tune it. These
|
|
|
|
* can be independent method.
|
|
|
|
* @param profiling Whether to print breakdown of time
|
|
|
|
*/
|
|
|
|
virtual void run(const Graph &graph, bool tune = false,
|
|
|
|
bool profiling = false) const = 0;
|
|
|
|
virtual void *alloc(size_t size) = 0;
|
|
|
|
virtual void dealloc(void *ptr) = 0;
|
|
|
|
/**
|
|
|
|
* @brief Get the execution time of each operator in performance record. No
|
|
|
|
* execution happens.
|
|
|
|
*
|
|
|
|
* @param graph
|
|
|
|
* @param profiling Whether to print breakdown of time
|
|
|
|
* @return double Return the sum of perf time for each operator
|
|
|
|
*/
|
|
|
|
double getPerfTime(const Graph &graph, bool profiling = false) const;
|
|
|
|
Blob allocBlob(size_t size);
|
2023-03-27 21:28:49 +08:00
|
|
|
bool isCpu() const {
|
2023-04-17 12:15:23 +08:00
|
|
|
return device == Device::CPU || device == Device::INTELCPU;
|
2023-03-27 21:28:49 +08:00
|
|
|
}
|
2022-08-23 16:55:59 +08:00
|
|
|
bool isCuda() const { return device == Device::CUDA; }
|
2022-09-22 16:57:39 +08:00
|
|
|
bool isBang() const { return device == Device::BANG; }
|
2023-10-16 10:57:08 +08:00
|
|
|
bool isKUNLUN() const { return device == Device::KUNLUN; }
|
2022-08-23 16:55:59 +08:00
|
|
|
void copyBlob(const TensorObj *dst, const TensorObj *src) const;
|
2022-09-22 18:06:45 +08:00
|
|
|
// TODO: unify these copy APIs
|
|
|
|
virtual void copyBlobFromCPU(void *dst, const void *src,
|
|
|
|
size_t bytes) const = 0;
|
2022-09-29 14:44:20 +08:00
|
|
|
virtual void copyBlobToCPU(void *dst, const void *src,
|
|
|
|
size_t bytes) const = 0;
|
2022-10-07 16:03:11 +08:00
|
|
|
virtual string toString() const = 0;
|
2022-08-22 15:01:03 +08:00
|
|
|
|
2023-09-05 09:47:35 +08:00
|
|
|
int getDeviceId() const { return deviceId; }
|
|
|
|
|
|
|
|
virtual void initComm(const string &name, int worldSize, int rank) = 0;
|
|
|
|
|
|
|
|
virtual CommunicatorObj &getCommunicator() const = 0;
|
|
|
|
|
2022-08-22 15:01:03 +08:00
|
|
|
protected:
|
|
|
|
void printProfilingData(double totTime,
|
|
|
|
const std::map<OpType, double> &opTime,
|
|
|
|
const std::map<OpType, int> &opCnt) const;
|
2022-09-22 18:06:45 +08:00
|
|
|
virtual void copyBlobInsideRuntime(void *dst, const void *src,
|
2022-08-23 16:55:59 +08:00
|
|
|
size_t bytes) const = 0;
|
2022-08-22 15:01:03 +08:00
|
|
|
};
|
|
|
|
|
|
|
|
class CpuRuntimeObj : public RuntimeObj {
|
|
|
|
public:
|
2023-03-27 21:28:49 +08:00
|
|
|
CpuRuntimeObj(Device dev) : RuntimeObj(dev) {}
|
2022-08-22 15:01:03 +08:00
|
|
|
|
|
|
|
void run(const Graph &graph, bool tune = false,
|
|
|
|
bool profiling = false) const override;
|
2022-08-23 16:55:59 +08:00
|
|
|
|
2022-09-22 18:06:45 +08:00
|
|
|
void copyBlobFromCPU(void *dst, const void *src,
|
|
|
|
size_t bytes) const override;
|
|
|
|
void copyBlobToCPU(void *dst, const void *src, size_t bytes) const override;
|
|
|
|
void copyBlobInsideRuntime(void *dst, const void *src,
|
2022-08-23 16:55:59 +08:00
|
|
|
size_t bytes) const override;
|
2023-09-05 09:47:35 +08:00
|
|
|
void initComm(const string &, int, int) override { IT_TODO_HALT(); }
|
|
|
|
|
|
|
|
CommunicatorObj &getCommunicator() const override { IT_TODO_HALT(); }
|
2023-03-27 21:28:49 +08:00
|
|
|
};
|
|
|
|
|
|
|
|
class NativeCpuRuntimeObj : public CpuRuntimeObj {
|
|
|
|
public:
|
|
|
|
NativeCpuRuntimeObj() : CpuRuntimeObj(Device::CPU) {}
|
|
|
|
|
|
|
|
static Ref<NativeCpuRuntimeObj> &getInstance() {
|
|
|
|
static Ref<NativeCpuRuntimeObj> instance =
|
|
|
|
make_ref<NativeCpuRuntimeObj>();
|
|
|
|
return instance;
|
|
|
|
}
|
|
|
|
void dealloc(void *ptr) override { return free(ptr); };
|
|
|
|
|
|
|
|
void *alloc(size_t size) override {
|
|
|
|
return calloc((size + sizeof(uint64_t) - 1) / sizeof(uint64_t),
|
|
|
|
sizeof(uint64_t));
|
|
|
|
};
|
2022-10-07 16:03:11 +08:00
|
|
|
string toString() const override;
|
2022-08-22 15:01:03 +08:00
|
|
|
};
|
|
|
|
|
2022-09-22 16:57:39 +08:00
|
|
|
} // namespace infini
|