diff --git a/include/kunlun/kunlun_runtime.h b/include/kunlun/kunlun_runtime.h
index 0c175158..b279c4b6 100644
--- a/include/kunlun/kunlun_runtime.h
+++ b/include/kunlun/kunlun_runtime.h
@@ -22,9 +22,9 @@ class KUNLUNRuntimeObj : public RuntimeObj {
         // 10GB for Longformer
         // size_t longformerNum = 3lu * (1 << 30);
         size_t workspaceSize = 3llu << 30; // 3 GB
-        KUNLUNPtr wkspacePtr = alloc(workspaceSize);
-        workspace =
-            make_ref<WorkspaceObj<KUNLUNPtr>>(wkspacePtr, workspaceSize);
+        // KUNLUNPtr wkspacePtr = alloc(workspaceSize);
+        // workspace =
+        //     make_ref<WorkspaceObj<KUNLUNPtr>>(wkspacePtr, workspaceSize);
     }
     virtual ~KUNLUNRuntimeObj() {
         KUNLUNPtr wkspacePtr = workspace->getWorkspace();
@@ -33,6 +33,16 @@ class KUNLUNRuntimeObj : public RuntimeObj {
     }
     string toString() const override;
 
+    // Installs `workspacePtr` as the buffer wrapped by `workspace`.
+    // NOTE(review): `workspaceSize` is only visible as a local of the
+    // constructor; confirm a member of that name exists before merging.
+    // NOTE(review): the destructor dereferences `workspace`, which stays
+    // unset until this is called -- confirm that path is safe.
+    void init_workspace(KUNLUNPtr workspacePtr) {
+        workspace =
+            make_ref<WorkspaceObj<KUNLUNPtr>>(workspacePtr, workspaceSize);
+    }
+
     void run(const Graph &graph, bool tune = false,
              bool profiling = false) const;
     // double runEvaluation(const Graph &graph, int nWarmups,
diff --git a/src/core/graph.cc b/src/core/graph.cc
index ac90344a..cd469e50 100644
--- a/src/core/graph.cc
+++ b/src/core/graph.cc
@@ -265,6 +265,14 @@ void GraphObj::dataMalloc(bool useNaiveAllocator, size_t memPoolSize) {
                     tensorToOffset[tensor.get()]));
         }
     }
+
+    // On KUNLUN, reserve the workspace through the allocator and hand its
+    // heap pointer to the runtime.
+    // NOTE(review): needs `workspaceSize` and `init_workspace` to be
+    // reachable through `runtime`; confirm against its declared type.
+    if (runtime->isKUNLUN()) {
+        allocator.heapAlloc(runtime->workspaceSize);
+        runtime->init_workspace(allocator.getHeapPtr());
+    }
 }
 
 Tensor GraphObj::cloneKV(Tensor &tensor) {