Fix: GraphObj ctor should use cloneTensor

This commit is contained in:
Liyan Zheng 2023-03-14 21:32:15 +08:00
parent 220b76f355
commit 6441abb88f
2 changed files with 49 additions and 37 deletions

View File

@ -11,10 +11,10 @@ GraphObj::GraphObj(Runtime runtime, OpVec ops_in)
for (const auto &op : ops_in) {
for (const auto &t : op->getInputs())
if (tensorPool.find(t->getFuid()) == tensorPool.end())
tensorPool[t->getFuid()] = t->clone();
tensorPool[t->getFuid()] = cloneTensor(t);
for (const auto &t : op->getOutputs())
if (tensorPool.find(t->getFuid()) == tensorPool.end())
tensorPool[t->getFuid()] = t->clone();
tensorPool[t->getFuid()] = cloneTensor(t);
}
for (const auto &[_, t] : tensorPool)
addTensor(t);
@ -127,8 +127,12 @@ Tensor GraphObj::addTensor(Shape dim, DataType dtype) {
}
Tensor GraphObj::addTensor(const Tensor &tensor) {
IT_ASSERT(tensor->getRuntime() == runtime, "Tensor runtime mismatch");
return tensors.emplace_back(tensor);
IT_ASSERT(tensor->getRuntime() == runtime,
std::string("Tensor runtime mismatch: cannot add a tenosr in ") +
tensor->getRuntime()->toString() + " to " +
runtime->toString());
tensors.emplace_back(tensor);
return tensor;
}
TensorVec GraphObj::addTensor(const TensorVec &tensors) {

View File

@ -56,15 +56,12 @@ TEST(Mutator, NaiveConvWithInterpreter) {
// FIXME: failed since implicit transpose for DLT
TEST(Mutator, InfoGAN_TConv_3_correctness) {
// verifyNaiveMembound True: subgraph after transformation
// verifyNaiveMembound False: subgraph of one single membound (eOP)
// const bool verifyNaiveMembound = false;
const bool verifyMutator = true;
Runtime runtime = make_ref<CudaRuntimeObj>();
Graph g = make_ref<GraphObj>(runtime);
Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton
Graph gCpu = make_ref<GraphObj>(cpu);
// TODO: recover me for InfoGAN
const int n = 1, c = 256, h = 2, w = 2, f = 448, r = 4, s = 4;
// // Minimum config for test
// const int n = 1, c = 1, h = 2, w = 2, f = 1, r = 4, s = 4;
@ -76,39 +73,50 @@ TEST(Mutator, InfoGAN_TConv_3_correctness) {
auto mutator =
make_ref<NMutator>(NMutator::Mode::RuleBased,
vector<int>{3, 2, 2, 2, 2, 5, 8, 8, 6, 91, 90});
// mutator->setToNaiveMembound();
SearchEngine searchEngine(runtime, mutator);
auto bestGraph = searchEngine.run(g);
bestGraph->print();
printf("--- SearchEngine Finished ---\n");
vector<Graph> bestGraphs;
if (verifyMutator) {
SearchEngine searchEngine(runtime, mutator);
bestGraphs.emplace_back(searchEngine.run(g));
printf("--- SearchEngine Finished ---\n");
} else {
bestGraphs = mutator->run(g);
}
g->dataMalloc();
bestGraph->dataMalloc();
for (auto t : g->getInputs()) {
t->setData(IncrementalGenerator());
}
for (auto t : bestGraph->getInputs()) {
t->setData(IncrementalGenerator());
}
for (auto t : g->getOutputs()) {
t->setData(IncrementalGenerator());
}
for (auto t : bestGraph->getOutputs()) {
t->setData(IncrementalGenerator());
}
runtime->run(g);
// puts("cuDNN");
// g->getOutputs()[0]->printData();
runtime->run(bestGraph);
// puts("Output");
// bestGraph->getOutputs()[0]->printData();
for (auto bestGraph : bestGraphs) {
puts("====== New best graph");
bestGraph->print();
auto go0 = gCpu->cloneTensor(g->getOutputs()[0]);
auto bgo0 = gCpu->cloneTensor(bestGraph->getOutputs()[0]);
g->dataMalloc();
bestGraph->dataMalloc();
for (auto t : g->getInputs()) {
t->setData(IncrementalGenerator());
}
for (auto t : bestGraph->getInputs()) {
t->setData(IncrementalGenerator());
}
for (auto t : g->getOutputs()) {
t->setData(IncrementalGenerator());
}
for (auto t : bestGraph->getOutputs()) {
t->setData(IncrementalGenerator());
}
const bool tuningOp = true; // FIXME: some eOps are not tuned
runtime->run(g);
// puts("cuDNN");
// g->getOutputs()[0]->printData();
runtime->run(bestGraph, tuningOp);
// puts("Output");
// bestGraph->getOutputs()[0]->printData();
EXPECT_TRUE(go0->equalData(bgo0));
EXPECT_TRUE(g->getOutputs()[0]->getRawDataPtr<void *>() !=
bestGraph->getOutputs()[0]->getRawDataPtr<void *>());
auto go0 = gCpu->cloneTensor(g->getOutputs()[0]);
auto bgo0 = gCpu->cloneTensor(bestGraph->getOutputs()[0]);
EXPECT_TRUE(go0->equalData(bgo0));
EXPECT_TRUE(g->getOutputs()[0]->getRawDataPtr<void *>() !=
bestGraph->getOutputs()[0]->getRawDataPtr<void *>());
}
}
// TEST(Mutator, Conv9x9) {