diff --git a/examples/tutorial/0_interactive.py b/examples/tutorial/0_interactive.py index c164a13..4bddbe4 100644 --- a/examples/tutorial/0_interactive.py +++ b/examples/tutorial/0_interactive.py @@ -2,8 +2,15 @@ from transformers import BertForMaskedLM model = BertForMaskedLM.from_pretrained("bert-base-cased") # suppose we load BERT +import sys + +if len(sys.argv) == 1: + port=True +else: + port=int(sys.argv[1]) + from opendelta import LoraModel -delta_model = LoraModel(backbone_model=model, interactive_modify=True) +delta_model = LoraModel(backbone_model=model, interactive_modify=port) # This will visualize the backbone after modification and other information. delta_model.freeze_module(exclude=["deltas", "layernorm_embedding"], set_state_dict=True) diff --git a/opendelta/utils/interactive/templates/index.html b/opendelta/utils/interactive/templates/index.html index bf6d411..c6235f8 100644 --- a/opendelta/utils/interactive/templates/index.html +++ b/opendelta/utils/interactive/templates/index.html @@ -165,7 +165,7 @@ for (i = 0; i < coll.length; i++) { var submit = document.getElementById("submit"); submit.addEventListener("click", function() { const Http = new XMLHttpRequest(); - const url='/submit/?name='+array.join(";"); + const url='/submit/?name='+array.join("&name="); Http.open("GET", url); Http.send(); alert("Now go back to your console") diff --git a/opendelta/utils/interactive/web.py b/opendelta/utils/interactive/web.py index a7da490..2ff6575 100644 --- a/opendelta/utils/interactive/web.py +++ b/opendelta/utils/interactive/web.py @@ -107,7 +107,7 @@ class hello: class submit: def GET(self, _): global names - names = [name.strip("root.") for name in web.input().name.split(";")] + names = [name[5:] for name in web.input(name=[]).name] app.stop() def interactive(model, port=8888): @@ -120,7 +120,7 @@ def interactive(model, port=8888): print("If on your machine, open the link below for interactive modification.\n " "If on remote host, you could use port mapping, " "or run in vscode terminal, which automatically do port mapping for you.") - app.run() + app.run(port) global names print("modified_modules:") print(names)