forked from jiuyuan/InfiniTensor

Add: search engine uses estimated time
commit e72fe79168, parent 26f0d13c26
@@ -59,10 +59,11 @@ class RuntimeObj : public std::enable_shared_from_this<RuntimeObj> {
      * execution happens.
      *
      * @param graph
-     * @param profiling Whether to print breakdown of time
+     * @param printProfiling Whether to print breakdown of time
      * @return double Return the sum of perf time for each operator
      */
-    double getPerfTime(const Graph &graph, bool profiling = false) const;
+    double getPerfTime(const Graph &graph, bool printProfiling = false,
+                       bool allowEstimation = false) const;
     Blob allocBlob(size_t size);
     bool isCpu() const {
         return device == Device::CPU || device == Device::INTELCPU;
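The new `allowEstimation` flag only changes where per-operator times come from; the defaults keep existing call sites behaving as before. A minimal calling sketch (`runtime` and `graph` are placeholders for objects constructed elsewhere):

    // Old behavior: a missing perf record triggers kernel tuning.
    double tuned = runtime->getPerfTime(graph);
    // New: MemBound operators may report their analytic estimate instead,
    // so ranking candidate graphs does not require tuning every kernel.
    double estimated = runtime->getPerfTime(graph, /*printProfiling=*/false,
                                            /*allowEstimation=*/true);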
@@ -11,11 +11,15 @@ class SearchEngine {
   private:
     Runtime runtimeExec;
     Ref<Mutator> mutator;
+    std::function<bool(const Graph &, const Graph &)> graphTimeComparer;

   public:
-    SearchEngine(Runtime _runtime, Ref<Mutator> _mutator) {
-        runtimeExec = _runtime;
-        mutator = _mutator;
+    SearchEngine(Runtime runtime, Ref<Mutator> mutator)
+        : runtimeExec(runtime), mutator(mutator) {
+        // Compare graph with estimated time
+        graphTimeComparer = [this](const Graph &a, const Graph &b) -> bool {
+            return getEstimatedGraphPerf(a) < getEstimatedGraphPerf(b);
+        };
     }
     ~SearchEngine() {}
@@ -24,11 +28,7 @@ class SearchEngine {
         3; // cut nodes whose #in + #out >= partitionThreshold
     size_t GRAPH_SIZE = 16; // num of best graphs.

-  private: // Composed objects
-    std::shared_ptr<Mutator> mutationEngine;
-
-  public:
-    std::shared_ptr<Mutator> getMutationEngine() { return mutationEngine; };
     struct GroupEdge {
         int v, next;
         GroupEdge() = delete;
@@ -38,10 +38,7 @@ class SearchEngine {
         std::shared_ptr<Graph> graph;
         double perf = INFINITY;
     };
-    class MetaGraph { // a graph of subgraphs, for searching.
-      public:
-        MetaGraph() {}
-        ~MetaGraph() {}
+    struct MetaGraph { // a graph of subgraphs, for searching.
         struct Node {
             Graph graph;
             std::vector<int> suc;
@@ -51,7 +48,7 @@ class SearchEngine {
         std::vector<Node> nodes;
     };

-    Graph run(const Graph graph); // entrance of search engine.
+    Graph run(const Graph graph); // entrance to search engine.
     std::vector<Graph> search(const Graph &graph); // search for a partition.

   private:
@@ -76,5 +73,9 @@ class SearchEngine {
     * branch.
     */
    bool isMultiBranchMergable(const Graph graph);
+
+    double getEstimatedGraphPerf(Graph graph) {
+        return runtimeExec->getPerfTime(graph, false, true);
+    }
 };
 } // namespace infini
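The constructor wires `graphTimeComparer` to `getEstimatedGraphPerf`, so the sort sites in `run`, `search`, and `searchMutation` share one ordering instead of each repeating an inline lambda. Expanded, each comparison evaluates roughly the following sketch (member names as in the diff; a paraphrase, not a verbatim excerpt):

    // Inside SearchEngine: what graphTimeComparer(a, b) computes.
    // getPerfTime(graph, /*printProfiling=*/false, /*allowEstimation=*/true)
    // falls back to MemBound estimates rather than tuning unseen kernels.
    bool aFaster = runtimeExec->getPerfTime(a, false, true) <
                   runtimeExec->getPerfTime(b, false, true);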
@@ -30,6 +30,7 @@ class MemBoundObj : public OperatorObj {
     pair<const nnet::Expr, HashType> getSimplifiedNnetExpr() const {
         return {expr, hash};
     }
+    double getEstimatedTime() const { return exec_time; }

   private:
     vector<int> getWorkloadVector() const override;
@@ -1,4 +1,5 @@
 import re
+from contextlib import redirect_stdout
+import time

 import numpy as np
 import tvm
@@ -8,99 +9,111 @@ import json
 import logging

 USE_CACHE = True
 logging.basicConfig()
 logger = logging.getLogger('InfiniTensor')
-logger.setLevel(logging.DEBUG)
+logger.setLevel(logging.INFO)


 def gen_ansor_so(input_tensors, input_dtypes, output_tensor, output_dtype,
                  tvm_code, func_name, nnet_expression: str,
                  nnet_simplified_expression: str, hash_code=None):
     assert len(input_tensors) == len(input_dtypes)

-    logging.debug(f'Work on hash {hash_code}')
+    logger.debug(f'Work on hash {hash_code}')
     dir_name = os.path.join(".cache", "generated_kernels", str(hash_code))

     if not os.path.exists(dir_name):
         os.makedirs(dir_name)

     so_fn = os.path.join(dir_name, f"{func_name}.so")
     config_fn = os.path.join(dir_name, "config_so.json")
+    desc_fn = os.path.join(dir_name, "desc.txt")
+    log_fn = os.path.join(dir_name, f"ansor_{func_name}_log.json")
+    out_fn = os.path.join(dir_name, "out.txt")

-    print("Generating Ansor op: ")
-    print(tvm_code)
-    print("Input shape: ")
-    print(input_tensors)
-    print("Output shape: ")
-    print(output_tensor)
+    logger.debug(f"Generating Ansor op: {tvm_code}")
+    logger.debug(f"Input shape: {input_tensors}")
+    logger.debug(f"Output shape: {output_tensor}")

     if USE_CACHE and hash_code is not None:
         if os.path.exists(dir_name) and \
-            os.path.exists(so_fn) and \
-            os.path.exists(config_fn):
+                os.path.exists(so_fn) and \
+                os.path.exists(config_fn):
             print(f"Use cache in {dir_name}")
             with open(config_fn, "r") as config_fin:
                 config = json.loads(config_fin.read().strip())
                 conv_time = config["conv_time"]
-            logger.debug(f'Find tuning log for {hash_code}')
+            logger.info(f'Find tuning log for {hash_code} in {so_fn}')
             return so_fn, conv_time

+    logger.info(f"TVM Tuning kernel with hash {hash_code}. See {out_fn}")
+
+    time_start = time.perf_counter()
+    # Print descriptions of the task
+    if USE_CACHE and hash_code is not None:
+        with redirect_stdout(open(desc_fn, "w")):
+            print("====NNET tensor expression====")
+            print(nnet_expression+"\n")
+            print("====NNET simplified tensor expression====")
+            print(nnet_simplified_expression+"\n")
+            print("====TVM compute====")
+            print(tvm_code+"\n")
+            print("Input shape: ", input_tensors)
+            print("Output shape: ", output_tensor)

     @auto_scheduler.register_workload(func_name)
     def compute():
         _locals = locals()
-        exec(tvm_code, {'tvm': tvm, 'te': te, 'tir': tir, 'topi': topi}, _locals)
+        exec(tvm_code, {'tvm': tvm, 'te': te,
+             'tir': tir, 'topi': topi}, _locals)
         return _locals['ret']

     target = tvm.target.Target("cuda")

     task = auto_scheduler.SearchTask(func=func_name, args=(), target=target)

-    # Inspect the computational graph
-    print("Computational DAG:")
-    print(task.compute_dag)
-
-    log_file = f"ansor_{func_name}_log.json"
-    measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
-    tune_option = auto_scheduler.TuningOptions(
-        num_measure_trials=10,
-        runner=measure_ctx.runner,
-        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
-        verbose=2,
-    )
+    with redirect_stdout(open(out_fn, 'w')):
+        # Inspect the computational graph
+        print("Computational DAG:")
+        print(task.compute_dag)
+
+        measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
+        tune_option = auto_scheduler.TuningOptions(
+            num_measure_trials=10,
+            runner=measure_ctx.runner,
+            measure_callbacks=[auto_scheduler.RecordToFile(log_fn)],
+            verbose=2,
+        )

-    # Run auto-tuning (search)
-    task.tune(tune_option)
-    # Apply the best schedule
-    sch, args = task.apply_best(log_file)
+        # Run auto-tuning (search)
+        task.tune(tune_option)
+        # Apply the best schedule
+        sch, args = task.apply_best(log_fn)

-    # Kill the measurement process
-    del measure_ctx
-
-    func = tvm.build(sch, args, target, name=func_name)
-    func.export_library(so_fn)
-
-    ctx = tvm.cuda(0)
-    input_a = []
-    for i, (shape, dtype) in enumerate(zip(input_tensors, input_dtypes)):
-        a_np = np.random.uniform(size=shape).astype(dtype)
-        input_a.append(tvm.nd.array(a_np, ctx))
-    a_out = tvm.nd.array(np.zeros(output_tensor, dtype=output_dtype), ctx)
-    func(a_out, *input_a)
-    evaluator = func.time_evaluator(func.entry_name, ctx, number=100)
-    conv_time = evaluator(a_out, *input_a).mean * 1e3
+        time_end = time.perf_counter()
+
+        # Kill the measurement process
+        del measure_ctx
+
+        func = tvm.build(sch, args, target, name=func_name)
+        func.export_library(so_fn)
+
+        ctx = tvm.cuda(0)
+        input_a = []
+        for i, (shape, dtype) in enumerate(zip(input_tensors, input_dtypes)):
+            a_np = np.random.uniform(size=shape).astype(dtype)
+            input_a.append(tvm.nd.array(a_np, ctx))
+        a_out = tvm.nd.array(np.zeros(output_tensor, dtype=output_dtype), ctx)
+        func(a_out, *input_a)
+        evaluator = func.time_evaluator(func.entry_name, ctx, number=100)
+        conv_time = evaluator(a_out, *input_a).mean * 1e3

-    print("====NNET tensor expression====")
-    print(nnet_expression+"\n")
-    print("====NNET simplified tensor expression====")
-    print(nnet_simplified_expression+"\n")
-    print("====Time====")
-    print(conv_time)

     if USE_CACHE and hash_code is not None:
         with open(config_fn, "w") as config_fout:
             config_fout.write(json.dumps({
                 "conv_time": conv_time,
+                "tuning_time": time_end - time_start,
+                "timestamp": time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()),
             }, ensure_ascii=False, indent=2))

     return so_fn, conv_time
@@ -2,6 +2,7 @@
 #include "core/blob.h"
 #include "core/kernel.h"
 #include "core/perf_engine.h"
+#include "operators/membound.h"
 #include "utils/data_generator.h"
 #include <chrono>
 #include <cstring>
@@ -56,7 +57,8 @@ void CpuRuntimeObj::run(const Graph &graph, bool tune, bool profiling) const {
     printProfilingData(totalTime, opTime, opCnt);
 }

-double RuntimeObj::getPerfTime(const Graph &graph, bool profiling) const {
+double RuntimeObj::getPerfTime(const Graph &graph, bool profiling,
+                               bool allowEstimation) const {
     const auto &kernelRegistry = KernelRegistry::getInstance();
     auto &perfEngine = PerfEngine::getInstance();
     // Statistics
@@ -70,11 +72,16 @@ double RuntimeObj::getPerfTime(const Graph &graph, bool profiling) const {
         auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
         auto perfData = perfEngine.getPerfData(perfKey);

-        PerfRecord record;
+        double time = -1e9;
         // Tune the kernel if there is no record
-        if (!perfData) {
+        if (perfData) {
+            time = perfData->time;
+        } else if (allowEstimation && op->getOpType() == OpType::MemBound) {
+            time = as<MemBoundObj>(op)->getEstimatedTime();
+        } else {
             // TODO: should tensors automatically allocate when accessing data?
-            // allocate memory for empty tensors and release it after profiling
+            // allocate memory for empty tensors and release it after
+            // profiling
             TensorVec allocatedTensors;
             for (auto t : op->getInputs())
                 if (!t->hasData())
@@ -88,21 +95,20 @@ double RuntimeObj::getPerfTime(const Graph &graph, bool profiling) const {
            }

            // Profile operators and record the results
-            record = kernel->tune(op, this);
+            PerfRecord record = kernel->tune(op, this);
+            time = record->time;
            perfEngine.setPerfData(perfKey, record);

            // Free allocated memory
            for (auto t : allocatedTensors)
                t->freeData();
-        } else
-            record = perfData;
+        }

-        double t = record->time;
-        totalTime += t;
+        totalTime += time;
        if (profiling) {
            op->print();
-            printf(" op_time %lf\n", t);
-            opTime[op->getOpType()] += t;
+            printf(" op_time %lf\n", time);
+            opTime[op->getOpType()] += time;
            opCnt[op->getOpType()]++;
        }
    }
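Distilled, the per-operator time now resolves in priority order: tuned record first, analytic estimate second (only when allowed and the operator is MemBound), on-the-spot tuning as a last resort. A self-contained paraphrase of that branch; the names below are illustrative stubs, not InfiniTensor API:

    #include <functional>
    #include <optional>

    // Stub paraphrase of the lookup order in RuntimeObj::getPerfTime.
    double resolveOpTime(std::optional<double> tunedRecord,      // perfData->time
                         std::optional<double> memBoundEstimate, // getEstimatedTime()
                         bool allowEstimation,
                         const std::function<double()> &tuneNow) { // kernel->tune
        if (tunedRecord)
            return *tunedRecord;
        if (allowEstimation && memBoundEstimate)
            return *memBoundEstimate;
        return tuneNow(); // profiles the kernel and caches the record
    }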
@@ -32,8 +32,7 @@ Graph SearchEngine::run(const Graph graph) {
     IT_ASSERT(runtimeExec == graph->getRuntime());
     std::cout << "[INFO] original graph: " << std::endl;
     std::cout << graph->toString();
-    std::cout << "[INFO] perf: " << runtimeExec->getPerfTime(graph)
-              << std::endl;
+    std::cout << "[INFO] perf: " << getEstimatedGraphPerf(graph) << std::endl;

     std::vector<Graph> partitions = partitionGraph(graph);
@@ -65,9 +64,7 @@ Graph SearchEngine::run(const Graph graph) {
             nextGraphs.emplace_back(tmp);
         }
     }
-    std::sort(nextGraphs.begin(), nextGraphs.end(), [&](Graph x, Graph y) {
-        return runtimeExec->getPerfTime(x) < runtimeExec->getPerfTime(y);
-    });
+    std::sort(nextGraphs.begin(), nextGraphs.end(), graphTimeComparer);
     if (nextGraphs.size() > GRAPH_SIZE) {
         nextGraphs.resize(GRAPH_SIZE);
     }
@@ -81,7 +78,7 @@ Graph SearchEngine::run(const Graph graph) {
     for (size_t i = 0; i < bestGraphs.size(); i++) {
         std::cout << "bestGraph " << i << ":" << std::endl;
         std::cout << bestGraphs[i]->toString();
-        std::cout << "[INFO] perf: " << runtimeExec->getPerfTime(bestGraphs[i])
+        std::cout << "[INFO] perf: " << getEstimatedGraphPerf(bestGraphs[i])
                   << std::endl;
     }
@@ -102,9 +99,8 @@ std::vector<Graph> SearchEngine::search(const Graph &graph) {
         }
     }

-    sort(results.begin(), results.end(), [&](Graph x, Graph y) {
-        return runtimeExec->getPerfTime(x) < runtimeExec->getPerfTime(y);
-    }); // compare with perf time
+    // compare with perf time
+    std::sort(results.begin(), results.end(), graphTimeComparer);
     if (results.size() > GRAPH_SIZE) {
         results.resize(GRAPH_SIZE);
     }
@@ -360,9 +356,7 @@ std::vector<Graph> SearchEngine::searchMutation(
     for (auto g : nextGraphs) {
         g->dataMalloc();
     }
-    std::sort(nextGraphs.begin(), nextGraphs.end(), [&](Graph x, Graph y) {
-        return runtimeExec->getPerfTime(x) < runtimeExec->getPerfTime(y);
-    });
+    std::sort(nextGraphs.begin(), nextGraphs.end(), graphTimeComparer);
     if (nextGraphs.size() > GRAPH_SIZE) {
         nextGraphs.resize(GRAPH_SIZE);
     }
@@ -372,7 +366,7 @@ std::vector<Graph> SearchEngine::searchMutation(
 }

 bool SearchEngine::isMultiBranchMergable(const Graph graph) {
-    return mutationEngine->isMultiBranchMergable(graph);
+    return mutator->isMultiBranchMergable(graph);
 }

 // Split a graph into multiple independent graphs. Search engine will search for
@@ -24,8 +24,7 @@ string TensorObj::toString() const {
         ss << "nullptr data";
     string ret = "Tensor " + std::to_string(guid) + ", Fuid " +
                  std::to_string(fuid) + ", shape " + vecToString(shape) +
-                 ", dtype " + dtype.toString() + ", " + runtime->toString() +
-                 ", " + ss.str() + "\n";
+                 ", dtype " + dtype.toString();
     vector<UidBaseType> targetGuids;
     for (const auto &op : targets)
         targetGuids.emplace_back(op.lock()->getGuid());
@@ -34,6 +33,7 @@ string TensorObj::toString() const {
     else
         ret += ", source None";
     ret += ", targets " + vecToString(targetGuids);
+    ret += ", " + runtime->toString() + ", " + ss.str();
     return ret;
 }
@@ -0,0 +1,124 @@
+#include "core/blob.h"
+#include "core/dummy_mutator.h"
+#include "core/graph.h"
+#include "core/runtime.h"
+#include "core/search_engine.h"
+#include "cuda/cuda_runtime.h"
+#include "nnet/nmutator.h"
+#include "operators/conv.h"
+#include "test.h"
+#include <pybind11/stl.h>
+
+namespace infini {
+
+// NHWC format
+Graph getInfoGAN(int batch, Runtime runtime) {
+    Graph g = make_ref<GraphObj>(runtime);
+    vector<Tensor> weights;
+    vector<tuple<int, int, int, int>> cs{
+        // Channel, kernelSize, pad, stride
+        {448, 2, 0, 1}, {256, 4, 1, 2}, {128, 4, 1, 2},
+        {64, 4, 1, 2},  {32, 4, 1, 2},
+    };
+    Tensor input = g->addTensor({batch, 1, 1, 228});
+    for (auto [channel, kernelSize, pad, stride] : cs) {
+        int f = input->getDims()[3]; // n, h, w, f
+        auto weight =
+            g->addTensor({f, kernelSize, kernelSize, channel}); // f, r, s, c
+        input = g->addOp<ConvTransposed2dNHWCObj>(input, weight, nullptr, pad,
+                                                  pad, stride, stride, 1, 1)
+                    ->getOutput();
+        // TODO: activation
+    }
+    return g;
+}
+
+void printGraph(Graph g) {
+    g->print();
+    puts("============ Data ============");
+    for (auto t : g->getTensors()) {
+        dbg(t);
+        t->printData();
+    }
+}
+
+vector<Tensor> runInfoGAN() {
+    const bool useMutatorDirectly = true;
+    Runtime cuda = make_ref<CudaRuntimeObj>();
+    Runtime cpu = NativeCpuRuntimeObj::getInstance();
+    Graph gCpu = make_ref<GraphObj>(cpu);
+
+    Graph g = getInfoGAN(1, cuda);
+
+    auto mutator =
+        make_ref<NMutator>(NMutator::Mode::RuleBased,
+                           vector<int>{3, 2, 2, 2, 2, 5, 8, 8, 6, 91, 90});
+    // // Translate OP to membound without derivation
+    // mutator->setToNaiveMembound();
+
+    vector<Graph> bestGraphs;
+    SearchEngine searchEngine(cuda, mutator);
+    bestGraphs.emplace_back(searchEngine.run(g));
+    g->topo_sort();
+    dbg(g, bestGraphs[0], bestGraphs.size());
+    g->print();
+
+    g->dataMalloc();
+    map<UidBaseType, Tensor> fuidToInputTensor;
+    for (auto t : g->getInputs()) {
+        IT_ASSERT(fuidToInputTensor.count(t->getFuid()) == 0);
+        fuidToInputTensor[t->getFuid()] = t;
+    }
+
+    auto gen = RandomGenerator(-1, 1, 0);
+    for (auto t : g->getInputs()) {
+        t->setData(gen);
+    }
+    for (auto t : g->getOutputs()) {
+        t->setData(ZeroGenerator());
+    }
+    cuda->run(g);
+    dbg("Baseline graph");
+    printGraph(g);
+    dbg(cuda->getPerfTime(g, true));
+
+    for (size_t i = 0; i < bestGraphs.size(); i++) {
+        auto bestGraphCpu = bestGraphs[i];
+        auto bestGraph = make_ref<GraphObj>(cuda, bestGraphCpu->getOperators());
+        bestGraph->topo_sort();
+
+        bestGraph->dataMalloc();
+        // Initialize inputs with random data
+        for (auto t : bestGraph->getInputs()) {
+            t->copyData(fuidToInputTensor[t->getFuid()]);
+        }
+
+        // Initialize outputs with zeros
+        for (auto t : bestGraph->getOutputs()) {
+            t->setData(ZeroGenerator());
+        }
+
+        dbg(bestGraph);
+        dbg(bestGraph->getOutputs());
+
+        cuda->run(bestGraph, true);  // Tune kernels
+        cuda->run(bestGraph, false); // Execute transformed graph
+
+        auto go0 = gCpu->cloneTensor(g->getOutputs()[0]);
+        auto bgo0 = gCpu->cloneTensor(bestGraph->getOutputs()[0]);
+        // EXPECT_TRUE(go0->equalData(bgo0, 1e-3));
+        std::cout << go0->equalData(bgo0, 1e-3) << std::endl;
+        bgo0->printData();
+        go0->printData();
+        dbg(cuda->getPerfTime(bestGraph, true));
+
+        dbg("Best graph");
+        printGraph(bestGraph);
+        return {g->getOutputs()[0], bestGraph->getOutputs()[0]};
+    }
+    return {};
+}
+
+// TEST(ModelE2E, InfoGAN) { runInfoGAN(); }
+
+} // namespace infini
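The gtest hook at the bottom of the new file is committed commented out, so it builds without registering the end-to-end case. Presumably enabling it is a one-line change (assuming test.h pulls in the usual gtest macros):

    // Hypothetical re-enablement of the commented-out test hook.
    TEST(ModelE2E, InfoGAN) { runInfoGAN(); }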
@@ -49,7 +49,7 @@ void NMutator::runSingleOpToNaiveMembound(Graph in_graph,
     assert(computeOps.size() == 1);
     const auto &computeOp = computeOps[0];
     auto g = infini::make_ref<GraphObj>(in_graph->getRuntime());
-    auto expr = opToExpression(computeOp);
+    nnet::Expr expr = opToExpression(computeOp);
     auto inputsN = nnet::GetTensorsVisitor().get(expr);
     dbg(inputsN, expr);
     IT_ASSERT(inputsN.count("B") + inputsN.count("K") == 1,
@@ -258,6 +258,8 @@ nnet::Expr NMutator::opToExpression(Operator op) {
         const auto &[n, c, h, w, f, r, s] = convOp->getNCHWFRS();
         const auto &[ph, pw, sh, sw, dh, dw] = convOp->getPadStrideDilation();
         IT_ASSERT_TODO(convOp->getNumGroups() == 1);
-        if (r != 4)
-            return nullptr;
+        IT_ASSERT_TODO(r == 4);
+        IT_ASSERT_TODO(ph == pw);
+        IT_ASSERT_TODO(tie(sh, sw) == tuple(2, 2));