#pragma once #include "core/runtime.h" #include "core/workspace.h" #include "kunlun/kunlun_common.h" #ifdef INFINI_USE_XCCL #include "kunlun/xccl_communicator.h" #endif namespace infini { class KUNLUNRuntimeObj : public RuntimeObj { private: xdnn::Context *ctx; std::unique_ptr comm; // KUNLUNPtr workspace; // size_t workspaceSize; Workspace workspace; public: KUNLUNRuntimeObj(int deviceId = 0) : RuntimeObj(Device::KUNLUN) { xpu_set_device(deviceId); ctx = xdnn::create_context(); // 10GB for Longformer // size_t longformerNum = 3lu * (1 << 30); size_t workspaceSize = 2llu << 30; // 2 GB KUNLUNPtr wkspacePtr = alloc(workspaceSize); workspace = make_ref>(wkspacePtr, workspaceSize); } virtual ~KUNLUNRuntimeObj() { KUNLUNPtr wkspacePtr = workspace->getWorkspace(); dealloc(wkspacePtr); xdnn::destroy_context(ctx); } string toString() const override; void run(const Graph &graph, bool tune = false, bool profiling = false) const; // double runEvaluation(const Graph &graph, int nWarmups, // int nEvaluations) const; void sync() const; KUNLUNPtr alloc(size_t size) override { void *ptr; checkKUNLUNError( xpu_malloc((void **)&ptr, size, XPUMemoryKind::XPU_MEM_HBM)); return ptr; } void dealloc(void *ptr) override { xpu_free(ptr); } xdnn::Context *KUNLUNHandle() const { return ctx; } // Get $size workspace by bytes KUNLUNPtr getWorkspace(size_t size) const { auto ret = workspace->getWorkspace(size); return ret; } Workspace getWorkspaceObj() const { return workspace; } void copyBlobFromCPU(void *dst, const void *src, size_t bytes) const override { xpu_memcpy(dst, const_cast(src), bytes, XPUMemcpyKind::XPU_HOST_TO_DEVICE); } void copyBlobToCPU(void *dst, const void *src, size_t bytes) const override { xpu_memcpy(dst, const_cast(src), bytes, XPUMemcpyKind::XPU_DEVICE_TO_HOST); } void copyBlobInsideRuntime(void *dst, const void *src, size_t bytes) const override { xpu_memcpy(dst, const_cast(src), bytes, XPUMemcpyKind::XPU_DEVICE_TO_DEVICE); } void initComm(const string &name, int worldSize, int rank) final; CommunicatorObj &getCommunicator() const final { return *comm; } private: void runWithoutSync(const Graph &graph, bool tune, bool profiling) const; }; } // namespace infini