2022-09-22 16:57:39 +08:00
|
|
|
#pragma once
|
|
|
|
#include "bang/bang_common.h"
|
|
|
|
#include "core/runtime.h"
|
|
|
|
|
|
|
|
namespace infini {
|
|
|
|
|
|
|
|
class BangRuntimeObj : public RuntimeObj {
|
|
|
|
private:
|
|
|
|
cnnlHandle_t cnnl;
|
2024-01-03 13:28:03 +08:00
|
|
|
cnrtQueue_t queue;
|
|
|
|
std::unique_ptr<CommunicatorObj> comm;
|
2022-09-22 16:57:39 +08:00
|
|
|
BangPtr workspace;
|
|
|
|
size_t workspaceSize;
|
2023-11-24 15:16:25 +08:00
|
|
|
mutable size_t cursor;
|
2022-09-22 16:57:39 +08:00
|
|
|
|
|
|
|
public:
|
2024-01-03 13:28:03 +08:00
|
|
|
explicit BangRuntimeObj(int deviceId = 0)
|
|
|
|
: RuntimeObj(Device::BANG, deviceId) {
|
2023-03-29 15:47:32 +08:00
|
|
|
cnInit(0);
|
|
|
|
CNdev dev;
|
2024-01-03 13:28:03 +08:00
|
|
|
cnDeviceGet(&dev, deviceId);
|
2023-03-29 15:47:32 +08:00
|
|
|
checkBangError(cnrtSetDevice(dev));
|
|
|
|
checkBangError(cnrtQueueCreate(&queue));
|
2022-09-22 16:57:39 +08:00
|
|
|
|
|
|
|
checkCnnlError(cnnlCreate(&cnnl));
|
|
|
|
checkCnnlError(cnnlSetQueue(cnnl, queue));
|
|
|
|
// 10GB for Longformer
|
|
|
|
// size_t longformerNum = 3lu * (1 << 30);
|
|
|
|
workspaceSize = 7ll << 30; // 7 GB
|
2023-11-24 15:16:25 +08:00
|
|
|
cursor = 0;
|
2022-09-22 16:57:39 +08:00
|
|
|
workspace = alloc(workspaceSize);
|
|
|
|
}
|
|
|
|
virtual ~BangRuntimeObj() {
|
|
|
|
dealloc(workspace);
|
2024-01-03 13:28:03 +08:00
|
|
|
checkBangError(cnrtQueueDestroy(queue));
|
2022-09-22 16:57:39 +08:00
|
|
|
checkCnnlError(cnnlDestroy(cnnl));
|
|
|
|
}
|
2023-03-29 13:52:56 +08:00
|
|
|
string toString() const override;
|
2022-09-22 16:57:39 +08:00
|
|
|
|
|
|
|
void run(const Graph &graph, bool tune = false,
|
|
|
|
bool profiling = false) const;
|
|
|
|
// double runEvaluation(const Graph &graph, int nWarmups,
|
|
|
|
// int nEvaluations) const;
|
|
|
|
void sync() const;
|
|
|
|
BangPtr alloc(size_t size) override {
|
|
|
|
void *ptr;
|
|
|
|
checkBangError(cnrtMalloc(&ptr, size));
|
|
|
|
return ptr;
|
|
|
|
}
|
|
|
|
void dealloc(void *ptr) override { checkBangError(cnrtFree(ptr)); }
|
|
|
|
cnnlHandle_t cnnlHandle() const { return cnnl; }
|
|
|
|
BangPtr getWorkspace(size_t size) const {
|
2023-11-24 15:16:25 +08:00
|
|
|
IT_ASSERT((cursor + size) <= workspaceSize);
|
|
|
|
cursor += size;
|
|
|
|
void *temp = workspace;
|
|
|
|
temp += (cursor - size);
|
|
|
|
return temp;
|
2022-09-22 16:57:39 +08:00
|
|
|
}
|
|
|
|
|
2023-11-24 15:16:25 +08:00
|
|
|
void resetWorkspace() const { cursor = 0; }
|
|
|
|
|
2022-09-30 11:01:52 +08:00
|
|
|
void copyBlobFromCPU(void *dst, const void *src,
|
|
|
|
size_t bytes) const override {
|
|
|
|
checkBangError(cnrtMemcpy(dst, const_cast<void *>(src), bytes,
|
|
|
|
CNRT_MEM_TRANS_DIR_HOST2DEV));
|
2022-09-22 16:57:39 +08:00
|
|
|
}
|
|
|
|
|
2022-09-30 11:01:52 +08:00
|
|
|
void copyBlobToCPU(void *dst, const void *src,
|
|
|
|
size_t bytes) const override {
|
|
|
|
checkBangError(cnrtMemcpy(dst, const_cast<void *>(src), bytes,
|
|
|
|
CNRT_MEM_TRANS_DIR_DEV2HOST));
|
2022-09-22 16:57:39 +08:00
|
|
|
}
|
|
|
|
|
2022-09-30 11:01:52 +08:00
|
|
|
void copyBlobInsideRuntime(void *dst, const void *src,
|
2022-09-22 16:57:39 +08:00
|
|
|
size_t bytes) const override {
|
2022-09-30 11:01:52 +08:00
|
|
|
checkBangError(cnrtMemcpy(dst, const_cast<void *>(src), bytes,
|
|
|
|
CNRT_MEM_TRANS_DIR_PEER2PEER));
|
2022-09-22 16:57:39 +08:00
|
|
|
}
|
2024-01-03 13:28:03 +08:00
|
|
|
void initComm(const string &name, int worldSize, int rank) final;
|
|
|
|
CommunicatorObj &getCommunicator() const override { return *comm; }
|
|
|
|
cnrtQueue_t getBangQueue() const { return queue; }
|
2023-09-19 14:10:39 +08:00
|
|
|
|
2022-09-22 16:57:39 +08:00
|
|
|
private:
|
|
|
|
void runWithoutSync(const Graph &graph, bool tune, bool profiling) const;
|
|
|
|
};
|
|
|
|
|
|
|
|
} // namespace infini
|