forked from jiuyuan/InfiniTensor
Compare commits
1 Commits
master
...
dist_bench
Author | SHA1 | Date |
---|---|---|
Bolun | 25a3cedeb0 |
|
@ -30,6 +30,12 @@ def parse_args():
|
|||
action="store_true",
|
||||
help="whether to generate the standard results.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--run_torch",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="whether to run model using PyTorch.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
print("arg setting: ", args)
|
||||
return (
|
||||
|
@ -40,6 +46,7 @@ def parse_args():
|
|||
args.batch_size,
|
||||
args.length,
|
||||
args.gen_std,
|
||||
args.run_torch
|
||||
)
|
||||
|
||||
|
||||
|
@ -153,12 +160,49 @@ def getDiff(base, test):
|
|||
f"Max relative difference: {max_relative_diff}")
|
||||
return max_absolute_diff, max_relative_diff
|
||||
|
||||
def run_pytorch(model, n=10):
|
||||
from onnxsim import simplify
|
||||
from onnx import version_converter
|
||||
from onnx2torch import convert
|
||||
import torch
|
||||
import torch_mlu
|
||||
|
||||
model, check = simplify(model)
|
||||
target_version = 16
|
||||
converted_model = version_converter.convert_version(model, target_version)
|
||||
device = torch.device("mlu")
|
||||
torch_model = convert(converted_model)
|
||||
torch_model.to(device)
|
||||
torch_model.eval()
|
||||
print("Run model using pytorch.")
|
||||
|
||||
runtime = backend.BangRuntime(0)
|
||||
stub = OnnxStub(model, runtime)
|
||||
inputs = []
|
||||
for i, (name, tensor) in enumerate(stub.inputs.items()):
|
||||
input = torch.from_numpy(np.load(f"./data/input_{i}.npy"))
|
||||
input.to(device)
|
||||
inputs.append(input)
|
||||
|
||||
torch.mlu.synchronize()
|
||||
begin = time.time()
|
||||
with torch.no_grad():
|
||||
for _ in range(n):
|
||||
outputs = torch_model(*inputs)
|
||||
torch.mlu.synchronize()
|
||||
end = time.time()
|
||||
|
||||
avg_time = (end - begin) / n
|
||||
print(f"average time: {avg_time}")
|
||||
return outputs
|
||||
|
||||
|
||||
def main():
|
||||
nnodes, nproc_per_node, name, model_path, bs, length, gen_std = parse_args()
|
||||
nnodes, nproc_per_node, name, model_path, bs, length, gen_std, run_torch = parse_args()
|
||||
|
||||
model = onnx.load(model_path)
|
||||
|
||||
|
||||
# generate standart output
|
||||
if gen_std:
|
||||
print("Generate inputs and outputs.")
|
||||
|
@ -190,6 +234,9 @@ def main():
|
|||
|
||||
for w in workers:
|
||||
w.join()
|
||||
|
||||
if run_torch:
|
||||
run_pytorch(model)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue