forked from jiuyuan/InfiniTensor
Fix: GraphObj ctor should use cloneTensor
This commit is contained in:
parent
220b76f355
commit
6441abb88f
|
@ -11,10 +11,10 @@ GraphObj::GraphObj(Runtime runtime, OpVec ops_in)
|
||||||
for (const auto &op : ops_in) {
|
for (const auto &op : ops_in) {
|
||||||
for (const auto &t : op->getInputs())
|
for (const auto &t : op->getInputs())
|
||||||
if (tensorPool.find(t->getFuid()) == tensorPool.end())
|
if (tensorPool.find(t->getFuid()) == tensorPool.end())
|
||||||
tensorPool[t->getFuid()] = t->clone();
|
tensorPool[t->getFuid()] = cloneTensor(t);
|
||||||
for (const auto &t : op->getOutputs())
|
for (const auto &t : op->getOutputs())
|
||||||
if (tensorPool.find(t->getFuid()) == tensorPool.end())
|
if (tensorPool.find(t->getFuid()) == tensorPool.end())
|
||||||
tensorPool[t->getFuid()] = t->clone();
|
tensorPool[t->getFuid()] = cloneTensor(t);
|
||||||
}
|
}
|
||||||
for (const auto &[_, t] : tensorPool)
|
for (const auto &[_, t] : tensorPool)
|
||||||
addTensor(t);
|
addTensor(t);
|
||||||
|
@ -127,8 +127,12 @@ Tensor GraphObj::addTensor(Shape dim, DataType dtype) {
|
||||||
}
|
}
|
||||||
|
|
||||||
Tensor GraphObj::addTensor(const Tensor &tensor) {
|
Tensor GraphObj::addTensor(const Tensor &tensor) {
|
||||||
IT_ASSERT(tensor->getRuntime() == runtime, "Tensor runtime mismatch");
|
IT_ASSERT(tensor->getRuntime() == runtime,
|
||||||
return tensors.emplace_back(tensor);
|
std::string("Tensor runtime mismatch: cannot add a tenosr in ") +
|
||||||
|
tensor->getRuntime()->toString() + " to " +
|
||||||
|
runtime->toString());
|
||||||
|
tensors.emplace_back(tensor);
|
||||||
|
return tensor;
|
||||||
}
|
}
|
||||||
|
|
||||||
TensorVec GraphObj::addTensor(const TensorVec &tensors) {
|
TensorVec GraphObj::addTensor(const TensorVec &tensors) {
|
||||||
|
|
|
@ -56,15 +56,12 @@ TEST(Mutator, NaiveConvWithInterpreter) {
|
||||||
|
|
||||||
// FIXME: failed since implicit transpose for DLT
|
// FIXME: failed since implicit transpose for DLT
|
||||||
TEST(Mutator, InfoGAN_TConv_3_correctness) {
|
TEST(Mutator, InfoGAN_TConv_3_correctness) {
|
||||||
// verifyNaiveMembound True: subgraph after transformation
|
const bool verifyMutator = true;
|
||||||
// verifyNaiveMembound False: subgraph of one single membound (eOP)
|
|
||||||
// const bool verifyNaiveMembound = false;
|
|
||||||
Runtime runtime = make_ref<CudaRuntimeObj>();
|
Runtime runtime = make_ref<CudaRuntimeObj>();
|
||||||
Graph g = make_ref<GraphObj>(runtime);
|
Graph g = make_ref<GraphObj>(runtime);
|
||||||
Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton
|
Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton
|
||||||
Graph gCpu = make_ref<GraphObj>(cpu);
|
Graph gCpu = make_ref<GraphObj>(cpu);
|
||||||
|
|
||||||
// TODO: recover me for InfoGAN
|
|
||||||
const int n = 1, c = 256, h = 2, w = 2, f = 448, r = 4, s = 4;
|
const int n = 1, c = 256, h = 2, w = 2, f = 448, r = 4, s = 4;
|
||||||
// // Minimum config for test
|
// // Minimum config for test
|
||||||
// const int n = 1, c = 1, h = 2, w = 2, f = 1, r = 4, s = 4;
|
// const int n = 1, c = 1, h = 2, w = 2, f = 1, r = 4, s = 4;
|
||||||
|
@ -76,39 +73,50 @@ TEST(Mutator, InfoGAN_TConv_3_correctness) {
|
||||||
auto mutator =
|
auto mutator =
|
||||||
make_ref<NMutator>(NMutator::Mode::RuleBased,
|
make_ref<NMutator>(NMutator::Mode::RuleBased,
|
||||||
vector<int>{3, 2, 2, 2, 2, 5, 8, 8, 6, 91, 90});
|
vector<int>{3, 2, 2, 2, 2, 5, 8, 8, 6, 91, 90});
|
||||||
|
|
||||||
// mutator->setToNaiveMembound();
|
// mutator->setToNaiveMembound();
|
||||||
SearchEngine searchEngine(runtime, mutator);
|
vector<Graph> bestGraphs;
|
||||||
auto bestGraph = searchEngine.run(g);
|
if (verifyMutator) {
|
||||||
bestGraph->print();
|
SearchEngine searchEngine(runtime, mutator);
|
||||||
printf("--- SearchEngine Finished ---\n");
|
bestGraphs.emplace_back(searchEngine.run(g));
|
||||||
|
printf("--- SearchEngine Finished ---\n");
|
||||||
|
} else {
|
||||||
|
bestGraphs = mutator->run(g);
|
||||||
|
}
|
||||||
|
|
||||||
g->dataMalloc();
|
for (auto bestGraph : bestGraphs) {
|
||||||
bestGraph->dataMalloc();
|
puts("====== New best graph");
|
||||||
for (auto t : g->getInputs()) {
|
bestGraph->print();
|
||||||
t->setData(IncrementalGenerator());
|
|
||||||
}
|
|
||||||
for (auto t : bestGraph->getInputs()) {
|
|
||||||
t->setData(IncrementalGenerator());
|
|
||||||
}
|
|
||||||
for (auto t : g->getOutputs()) {
|
|
||||||
t->setData(IncrementalGenerator());
|
|
||||||
}
|
|
||||||
for (auto t : bestGraph->getOutputs()) {
|
|
||||||
t->setData(IncrementalGenerator());
|
|
||||||
}
|
|
||||||
runtime->run(g);
|
|
||||||
// puts("cuDNN");
|
|
||||||
// g->getOutputs()[0]->printData();
|
|
||||||
runtime->run(bestGraph);
|
|
||||||
// puts("Output");
|
|
||||||
// bestGraph->getOutputs()[0]->printData();
|
|
||||||
|
|
||||||
auto go0 = gCpu->cloneTensor(g->getOutputs()[0]);
|
g->dataMalloc();
|
||||||
auto bgo0 = gCpu->cloneTensor(bestGraph->getOutputs()[0]);
|
bestGraph->dataMalloc();
|
||||||
|
for (auto t : g->getInputs()) {
|
||||||
|
t->setData(IncrementalGenerator());
|
||||||
|
}
|
||||||
|
for (auto t : bestGraph->getInputs()) {
|
||||||
|
t->setData(IncrementalGenerator());
|
||||||
|
}
|
||||||
|
for (auto t : g->getOutputs()) {
|
||||||
|
t->setData(IncrementalGenerator());
|
||||||
|
}
|
||||||
|
for (auto t : bestGraph->getOutputs()) {
|
||||||
|
t->setData(IncrementalGenerator());
|
||||||
|
}
|
||||||
|
const bool tuningOp = true; // FIXME: some eOps are not tuned
|
||||||
|
runtime->run(g);
|
||||||
|
// puts("cuDNN");
|
||||||
|
// g->getOutputs()[0]->printData();
|
||||||
|
runtime->run(bestGraph, tuningOp);
|
||||||
|
// puts("Output");
|
||||||
|
// bestGraph->getOutputs()[0]->printData();
|
||||||
|
|
||||||
EXPECT_TRUE(go0->equalData(bgo0));
|
auto go0 = gCpu->cloneTensor(g->getOutputs()[0]);
|
||||||
EXPECT_TRUE(g->getOutputs()[0]->getRawDataPtr<void *>() !=
|
auto bgo0 = gCpu->cloneTensor(bestGraph->getOutputs()[0]);
|
||||||
bestGraph->getOutputs()[0]->getRawDataPtr<void *>());
|
|
||||||
|
EXPECT_TRUE(go0->equalData(bgo0));
|
||||||
|
EXPECT_TRUE(g->getOutputs()[0]->getRawDataPtr<void *>() !=
|
||||||
|
bestGraph->getOutputs()[0]->getRawDataPtr<void *>());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TEST(Mutator, Conv9x9) {
|
// TEST(Mutator, Conv9x9) {
|
||||||
|
|
Loading…
Reference in New Issue