forked from jiuyuan/InfiniTensor

Commit 9ce21200c4: Add: NMutator mode in python
Parent: b943658713
@@ -8,7 +8,8 @@ namespace infini {
 Graph getGANGraph(int batch, Runtime runtime, int nLayers, int modelId);
 vector<Tensor> runInfoGAN(int nLayers);
 Graph getConvtransposedNHWC(Runtime runtime, Shape shape, int layerId);
-Graph optimizeGraph(Graph g, Runtime runtime, bool tuning);
+Graph optimizeGraph(Graph g, Runtime _runtime, bool tuning, NMutator::Mode mode,
+                    vector<int> rules);
 void initializeGraphTensors(Graph g, double l, double r, bool useInt);

 } // namespace infini
@@ -323,7 +323,8 @@ void init_graph_builder(py::module &m) {
         .def(py::init<Graph>())
         .def("inputs", &Handler::inputs, policy::move)
         .def("outputs", &Handler::outputs, policy::move)
-        .def("tensor", &Handler::tensor, policy::move)
+        .def("tensor", &Handler::tensor, policy::move, "shape"_a, "dtype"_a = 1,
+             "tensor_type"_a = TensorType::Other)
        .def("conv", &Handler::conv, policy::move)
        .def("convTransposed2d", &Handler::convTransposed2d, policy::move)
        .def("matmul", &Handler::matmul, policy::move)
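With the extra keyword arguments on the "tensor" binding, graphs built from Python can tag tensors by role. A minimal sketch of the intended use, assuming ft is the bound backend module as in the test script later in this commit (the shapes and the 3x3 convolution are illustrative):

    handler = ft.GraphHandler(runtime)
    x = handler.tensor([1, 56, 32, 32], tensor_type=ft.TensorType.Input)
    w = handler.tensor([12, 56, 3, 3], tensor_type=ft.TensorType.Initialized)
    handler.conv(x, w, None, 1, 1, 1, 1, 1, 1)  # pads, strides, dilations
    g = handler.getGraph()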
@@ -362,6 +363,7 @@ void init_graph_builder(py::module &m) {
         .def("getGraph", &Handler::getGraph);
     py::class_<Mutator, Ref<Mutator>>(m, "Mutator").def("run", &Mutator::run);
     py::enum_<NMutator::Mode>(m, "NMutatorMode")
+        .value("Normal", NMutator::Mode::Normal)
         .value("RuleBased", NMutator::Mode::RuleBased);
     py::class_<NMutator, Ref<NMutator>, Mutator>(m, "NMutator")
         .def(py::init<NMutator::Mode>())
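The enum binding exposes both mutator modes to Python, and the NMutator constructor bound just below accepts one of them. A small sketch under the same assumption that ft is the bound module:

    normal = ft.NMutatorMode.Normal        # newly exposed by this commit
    rule_based = ft.NMutatorMode.RuleBased
    mutator = ft.NMutator(normal)          # uses the bound init<NMutator::Mode> constructor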
@@ -385,7 +387,8 @@ void export_test_model(py::module &m) {
         .def("getGANGraph", &getGANGraph)
         .def("getConvtransposedNHWC", &getConvtransposedNHWC)
         .def("optimizeGraph", &optimizeGraph, "graph"_a, "runtime"_a,
-             "tuning"_a = false)
+             "tuning"_a = false, "mode"_a = NMutator::Mode::Normal,
+             "rules"_a = vector<int>{})
         .def("initializeGraphTensors", &initializeGraphTensors, "g"_a,
              "l"_a = -0.1, "r"_a = 0.1, "useInt"_a = false);
 #endif
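Because both new arguments are defaulted, existing Python callers of optimizeGraph keep working, while new callers can pick a mode and rule set explicitly. A minimal sketch mirroring the test script at the end of this commit (the rule list is the Conv2gemm set from the C++ comments; ft is assumed to be the bound module):

    g_opt = ft.optimizeGraph(g, runtime, tuning=False,
                             mode=ft.NMutatorMode.RuleBased,
                             rules=[3, 2, 2, 5, 8, 8, 6, 90])
    # mode defaults to Normal and rules to an empty list, so
    # ft.optimizeGraph(g, runtime) remains a valid call.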
@@ -127,14 +127,23 @@ void initializeGraphTensors(Graph g, double l, double r, bool useInt) {
     }
 }

-Graph optimizeGraph(Graph g, Runtime _runtime, bool tuning) {
+Graph optimizeGraph(Graph g, Runtime _runtime, bool tuning, NMutator::Mode mode,
+                    vector<int> rules) {
     auto runtime = as<CudaRuntimeObj>(_runtime);
     Runtime cpu = NativeCpuRuntimeObj::getInstance();
     Graph gCpu = make_ref<GraphObj>(cpu);
-    auto mutator =
-        make_ref<NMutator>(NMutator::Mode::RuleBased,
-                           vector<int>{3, 2, 2, 2, 2, 5, 8, 8, 6, 91, 90});
+    // vector<int>{3, 2, 2, 5, 8, 8, 6, 90}); // Conv2gemm
+    // vector<int>{3, 2, 2, 2, 2, 5, 8, 8, 6, 91, 90}); // TConv
+    Ref<NMutator> mutator;
+    if (mode == NMutator::Mode::Normal) {
+        dbg(mode);
+        mutator = make_ref<NMutator>(mode);
+    } else if (mode == NMutator::Mode::RuleBased) {
+        dbg(mode, rules);
+        IT_ASSERT_TODO(rules.size() > 0);
+        mutator = make_ref<NMutator>(mode, rules);
+    } else
+        IT_TODO_HALT();
     vector<Graph> bestGraphs;
     SearchEngine searchEngine(runtime, mutator);
     bestGraphs.emplace_back(searchEngine.run(g));
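The dispatch above is what the new Python-facing parameters feed: Normal builds an NMutator from the mode alone, while RuleBased requires a non-empty rule list and forwards it. The two commented vectors record the rule sets used so far. A sketch of exercising both paths through the binding, again assuming the ft alias:

    conv2gemm_rules = [3, 2, 2, 5, 8, 8, 6, 90]         # from the Conv2gemm comment
    tconv_rules = [3, 2, 2, 2, 2, 5, 8, 8, 6, 91, 90]   # from the TConv comment
    g1 = ft.optimizeGraph(g, runtime, False, ft.NMutatorMode.Normal)
    g2 = ft.optimizeGraph(g, runtime, False, ft.NMutatorMode.RuleBased, tconv_rules)
    # RuleBased with an empty rule list trips IT_ASSERT_TODO on the C++ side.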
@@ -157,8 +166,8 @@ Graph optimizeGraph(Graph g, Runtime _runtime, bool tuning) {
             t->setData(ZeroGenerator());
     }
     runtime->run(g);
-    dbg("Baseline graph");
-    printGraph(g);
+    // dbg("Baseline graph");
+    // printGraph(g);
     // dbg(runtme->getPerfTime(g, true));

     for (size_t i = 0; i < bestGraphs.size(); i++) {
@@ -194,8 +203,8 @@ Graph optimizeGraph(Graph g, Runtime _runtime, bool tuning) {
            // dbg(runtime->timeWithCudaGraph(bestGraph));
        }

-        dbg("Best graph");
-        printGraph(bestGraph);
+        // dbg("Best graph");
+        // printGraph(bestGraph);
        return bestGraph;
    }
    return nullptr;
@@ -14,67 +14,6 @@ def to_pytorch_tensor(tensor) -> torch.Tensor:
     return tt.reshape(tensor.shape())


-def run_InfoGAN_return_tesnor(n_layers: int):
-    if_tensors = ft.runInfoGAN(n_layers)
-    tensors = [to_pytorch_tensor(t) for t in if_tensors]
-    return tensors
-
-
-def read_and_check():
-    for n_layers in range(1, 6):
-        ans = torch.load(f'torch_{n_layers}layers_0.pt')
-        x = torch.load(f'torch_{n_layers}layers_1.pt')
-        print(f'=== {n_layers} layers ===')
-        print(x.abs().max())
-
-
-def run_e2e_InfoGAN():
-    data = []
-    for n_layers in range(5, 6):
-        tensors = run_InfoGAN_return_tesnor(n_layers)
-        for i, t in enumerate(tensors):
-            torch.save(t, f'torch_{n_layers}layers_{i}.pt')
-        print(f'============ {n_layers} layers = = =')
-        ans, x = tensors
-        print(f'Allclose {torch.allclose(ans, x)}')
-
-        # Print error numbers
-        tot = np.product(ans.shape)
-        data.append([])
-        for i in range(0, 10):
-            tol = 10**(-i)
-            clo = torch.isclose(ans, x, atol=tol, rtol=tol).sum().item()
-            print(f'0.1^{i} close: {clo}/{tot} = {clo/tot}')
-            data[-1].append(clo/tot)
-
-        rel_err = torch.abs((ans-x)/ans)
-        print(rel_err, rel_err.max())
-        print(f'ans = {ans}')
-        print(f'x = {x}')
-
-        # # Plot CDF
-        # fig, axes = plt.subplots(9,1)
-        # print(axes)
-        # for i, ax in enumerate(axes):
-        #     print(i)
-        #     ax:plt.Axes
-        #     ax.hist(torch.flatten(rel_err), density=True, cumulative=True, label='CDF',
-        #             histtype='step', alpha=0.8, color='k')
-        #     ax.set_xlim(0, 10**(-i))
-        #     # ax.set_title('')
-        # plt.show()
-        # plt.savefig('a.pdf')
-    df = pd.DataFrame(data)
-    print(df.to_string())
-    df.set_axis([f'0.1^{i}' for i in range(0, 10)], axis=1, inplace=True)
-    print(df.to_string())
-    df.to_csv('a.csv')
-
-
-def getSingleConvT(runtime):
-    return ft.getConvtransposedNHWC(runtime, [1, 2, 2, 448], 1)
-
-
 def save_onnx(opt_g: ft.Graph, filename: str):
     stub = OnnxStub.from_graph(opt_g)
     with open(filename, "wb") as f:
@@ -93,7 +32,7 @@ def run_and_evaluate(runtime, g):
     print(f'Cuda graph time = {runtime.timeWithCudaGraph(g)}')


-def run_graph_get_output(runtime, g):
+def run_graph_get_output_as_torch_tensor(runtime, g):
     ft.initializeGraphTensors(g)
     runtime.run(g, True)
     runtime.run(g, False)
@@ -114,6 +53,9 @@ def compare_tensors(ans, x):
         print(f'0.1^{i} close: {clo}/{tot} = {clo/tot}')
         data.append(clo/tot)

+    # for i, t in enumerate(tensors):
+    #     torch.save(t, f'torch_{n_layers}layers_{i}.pt')
+
     # rel_err = torch.abs((ans-x)/ans)
     # print(f'rel_err = {rel_err}')
     # print(f'max rel err = {rel_err.max()}')
@@ -122,9 +64,9 @@ def compare_tensors(ans, x):


 def verify_graphs(runtime, g_original, g_new):
-    ans = run_graph_get_output(runtime, g_original)
-    x = run_graph_get_output(runtime, g_new)
-    compare_tensors(runtime, ans, x)
+    ans = run_graph_get_output_as_torch_tensor(runtime, g_original)
+    x = run_graph_get_output_as_torch_tensor(runtime, g_new)
+    compare_tensors(ans, x)


 def evluate_GANs():
@@ -133,7 +75,8 @@ def evluate_GANs():
     for batch in [1, 16]:
         if True:
             original_g = ft.getGANGraph(batch, runtime, 5, model_id)
-            g = ft.optimizeGraph(original_g, runtime, tuning=False)
+            g = ft.optimizeGraph(original_g, runtime, False, {
+                3, 2, 2, 2, 2, 5, 8, 8, 6, 91, 90})
         else:
             g = load_onnx(runtime)
         save_onnx(
@@ -142,21 +85,43 @@ def evluate_GANs():
         run_and_evaluate(runtime, g)


+def construct_convTranspose2d(runtime):
+    handler = ft.GraphHandler(runtime)
+    input = handler.tensor([1, 56, 32, 32], tensor_type=ft.TensorType.Input)
+    w = handler.tensor([56, 1, 9, 9], tensor_type=ft.TensorType.Initialized)
+    handler.convTransposed2d(input, w, None, 3, 3, 4, 4, 1, 1, 1, 1)
+    return handler.getGraph()
+
+
+def construct_conv(runtime, n, c, h, w, f, r, s, pad, stride, dilation):
+    handler = ft.GraphHandler(runtime)
+    # input = handler.tensor([1, 56, 32, 32], tensor_type=ft.TensorType.Input)
+    # w = handler.tensor([12, 56, 1, 1], tensor_type=ft.TensorType.Initialized)
+    # handler.conv(input, w, None, 0, 0, 1, 1, 1, 1)
+    input = handler.tensor([n, c, h, w], tensor_type=ft.TensorType.Input)
+    w = handler.tensor([f, c, r, s], tensor_type=ft.TensorType.Initialized)
+    handler.conv(input, w, None, pad, pad, stride, stride, dilation, dilation)
+    return handler.getGraph()
+
+
 if __name__ == "__main__":
     runtime = ft.cuda_runtime()
-    # run_e2e_InfoGAN()
-    # runSingleConvT()
-    # read_and_check()
-    for batch in [1, 16]:
-        if True:
-            original_g = ft.getGANGraph(batch, runtime, 5, 1)
-            # original_g = ft.getConvtransposedNHWC(runtime, [1, 1, 1, 228], 0) # ConvTranspose 2x2
-            # original_g = ft.getConvtransposedNHWC(runtime, [16, 2, 2, 448], 1) # ConvTranspose 4x4
-            g = ft.optimizeGraph(original_g, runtime, tuning=False)
-        else:
-            g = load_onnx(runtime)
-        save_onnx(g, f"dcgan_optimized_{batch}.onnx")
-
+    graphs = [
+        # (construct_conv(runtime, 16, 56, 32, 32, 12, 1, 1, 0, 1, 1), 'conv1x1'), # FSRCNN Conv_2 1x1
+        # (construct_conv(runtime, 1, 12, 32, 32, 12, 3, 3, 1, 1, 1), 'conv3x3'), # FSRCNN Conv_4 3x3
+        # ft.getGANGraph(batch, runtime, 5, 1)
+        # construct_convTranspose2d(runtime)
+        (load_onnx(runtime, '/mnt/auxHome/models/einnet/fsrcnn.bs1.onnx'), 'fsrcnn.bs1'),
+    ]
+    for original_g, name in graphs:
+        print(f"=== {name}")
+        if True:  # Optimization
+            save_onnx(original_g, f"orig_{name}.onnx")
+            g = ft.optimizeGraph(original_g, runtime, False, ft.NMutatorMode.RuleBased,
+                                 [3, 2, 2, 5, 8, 8, 6, 90])
+            # g = ft.optimizeGraph(original_g, runtime, False, ft.NMutatorMode.Normal)
+
+            save_onnx(g, f"optimized_{name}.onnx")
         verify_graphs(runtime, original_g, g)

         run_and_evaluate(runtime, g)