forked from jiuyuan/InfiniTensor
Compare commits
1 Commits
master
...
dist_bench
Author | SHA1 | Date |
---|---|---|
Bolun | 25a3cedeb0 |
|
@ -30,6 +30,12 @@ def parse_args():
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="whether to generate the standard results.",
|
help="whether to generate the standard results.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--run_torch",
|
||||||
|
default=False,
|
||||||
|
action="store_true",
|
||||||
|
help="whether to run model using PyTorch.",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
print("arg setting: ", args)
|
print("arg setting: ", args)
|
||||||
return (
|
return (
|
||||||
|
@ -40,6 +46,7 @@ def parse_args():
|
||||||
args.batch_size,
|
args.batch_size,
|
||||||
args.length,
|
args.length,
|
||||||
args.gen_std,
|
args.gen_std,
|
||||||
|
args.run_torch
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -153,12 +160,49 @@ def getDiff(base, test):
|
||||||
f"Max relative difference: {max_relative_diff}")
|
f"Max relative difference: {max_relative_diff}")
|
||||||
return max_absolute_diff, max_relative_diff
|
return max_absolute_diff, max_relative_diff
|
||||||
|
|
||||||
|
def run_pytorch(model, n=10):
|
||||||
|
from onnxsim import simplify
|
||||||
|
from onnx import version_converter
|
||||||
|
from onnx2torch import convert
|
||||||
|
import torch
|
||||||
|
import torch_mlu
|
||||||
|
|
||||||
|
model, check = simplify(model)
|
||||||
|
target_version = 16
|
||||||
|
converted_model = version_converter.convert_version(model, target_version)
|
||||||
|
device = torch.device("mlu")
|
||||||
|
torch_model = convert(converted_model)
|
||||||
|
torch_model.to(device)
|
||||||
|
torch_model.eval()
|
||||||
|
print("Run model using pytorch.")
|
||||||
|
|
||||||
|
runtime = backend.BangRuntime(0)
|
||||||
|
stub = OnnxStub(model, runtime)
|
||||||
|
inputs = []
|
||||||
|
for i, (name, tensor) in enumerate(stub.inputs.items()):
|
||||||
|
input = torch.from_numpy(np.load(f"./data/input_{i}.npy"))
|
||||||
|
input.to(device)
|
||||||
|
inputs.append(input)
|
||||||
|
|
||||||
|
torch.mlu.synchronize()
|
||||||
|
begin = time.time()
|
||||||
|
with torch.no_grad():
|
||||||
|
for _ in range(n):
|
||||||
|
outputs = torch_model(*inputs)
|
||||||
|
torch.mlu.synchronize()
|
||||||
|
end = time.time()
|
||||||
|
|
||||||
|
avg_time = (end - begin) / n
|
||||||
|
print(f"average time: {avg_time}")
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
nnodes, nproc_per_node, name, model_path, bs, length, gen_std = parse_args()
|
nnodes, nproc_per_node, name, model_path, bs, length, gen_std, run_torch = parse_args()
|
||||||
|
|
||||||
model = onnx.load(model_path)
|
model = onnx.load(model_path)
|
||||||
|
|
||||||
|
|
||||||
# generate standart output
|
# generate standart output
|
||||||
if gen_std:
|
if gen_std:
|
||||||
print("Generate inputs and outputs.")
|
print("Generate inputs and outputs.")
|
||||||
|
@ -191,6 +235,9 @@ def main():
|
||||||
for w in workers:
|
for w in workers:
|
||||||
w.join()
|
w.join()
|
||||||
|
|
||||||
|
if run_torch:
|
||||||
|
run_pytorch(model)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
Loading…
Reference in New Issue