Compare commits

...

1 Commits

Author SHA1 Message Date
Bolun 25a3cedeb0 add pytorch bench 2024-03-21 02:27:32 +00:00
1 changed files with 48 additions and 1 deletions

View File

@ -30,6 +30,12 @@ def parse_args():
action="store_true", action="store_true",
help="whether to generate the standard results.", help="whether to generate the standard results.",
) )
parser.add_argument(
"--run_torch",
default=False,
action="store_true",
help="whether to run model using PyTorch.",
)
args = parser.parse_args() args = parser.parse_args()
print("arg setting: ", args) print("arg setting: ", args)
return ( return (
@ -40,6 +46,7 @@ def parse_args():
args.batch_size, args.batch_size,
args.length, args.length,
args.gen_std, args.gen_std,
args.run_torch
) )
@ -153,12 +160,49 @@ def getDiff(base, test):
f"Max relative difference: {max_relative_diff}") f"Max relative difference: {max_relative_diff}")
return max_absolute_diff, max_relative_diff return max_absolute_diff, max_relative_diff
def run_pytorch(model, n=10):
from onnxsim import simplify
from onnx import version_converter
from onnx2torch import convert
import torch
import torch_mlu
model, check = simplify(model)
target_version = 16
converted_model = version_converter.convert_version(model, target_version)
device = torch.device("mlu")
torch_model = convert(converted_model)
torch_model.to(device)
torch_model.eval()
print("Run model using pytorch.")
runtime = backend.BangRuntime(0)
stub = OnnxStub(model, runtime)
inputs = []
for i, (name, tensor) in enumerate(stub.inputs.items()):
input = torch.from_numpy(np.load(f"./data/input_{i}.npy"))
input.to(device)
inputs.append(input)
torch.mlu.synchronize()
begin = time.time()
with torch.no_grad():
for _ in range(n):
outputs = torch_model(*inputs)
torch.mlu.synchronize()
end = time.time()
avg_time = (end - begin) / n
print(f"average time: {avg_time}")
return outputs
def main(): def main():
nnodes, nproc_per_node, name, model_path, bs, length, gen_std = parse_args() nnodes, nproc_per_node, name, model_path, bs, length, gen_std, run_torch = parse_args()
model = onnx.load(model_path) model = onnx.load(model_path)
# generate standart output # generate standart output
if gen_std: if gen_std:
print("Generate inputs and outputs.") print("Generate inputs and outputs.")
@ -191,6 +235,9 @@ def main():
for w in workers: for w in workers:
w.join() w.join()
if run_torch:
run_pytorch(model)
if __name__ == "__main__": if __name__ == "__main__":
main() main()