Skip to content

Commit 4ab399f

Browse files
Allow config to include non-pickle-able values (#7415)
* fixes * lint * add changeset * route utils --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
1 parent c2dfc59 commit 4ab399f

File tree

6 files changed

+83
-15
lines changed

6 files changed

+83
-15
lines changed

.changeset/mean-bushes-hide.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"gradio": patch
3+
---
4+
5+
fix:Allow config to include non-pickle-able values

gradio/processing_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,9 +289,11 @@ def _move_to_cache(d: dict):
289289
return client_utils.traverse(data, _move_to_cache, client_utils.is_file_obj)
290290

291291

292-
def add_root_url(data, root_url) -> dict:
292+
def add_root_url(data: dict, root_url: str, previous_root_url: str | None) -> dict:
293293
def _add_root_url(file_dict: dict):
294294
if not client_utils.is_http_url_like(file_dict["url"]):
295+
if previous_root_url and file_dict["url"].startswith(previous_root_url):
296+
file_dict["url"] = file_dict["url"][len(previous_root_url) :]
295297
file_dict["url"] = f'{root_url}{file_dict["url"]}'
296298
return file_dict
297299

gradio/route_utils.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from starlette.datastructures import FormData, Headers, UploadFile
1717
from starlette.formparsers import MultiPartException, MultipartPart
1818

19-
from gradio import utils
19+
from gradio import processing_utils, utils
2020
from gradio.data_classes import PredictBody
2121
from gradio.exceptions import Error
2222
from gradio.helpers import EventData
@@ -561,3 +561,16 @@ async def parse(self) -> FormData:
561561
def move_uploaded_files_to_cache(files: list[str], destinations: list[str]) -> None:
562562
for file, dest in zip(files, destinations):
563563
shutil.move(file, dest)
564+
565+
566+
def update_root_in_config(config: dict, root: str) -> dict:
567+
"""
568+
Updates the root "key" in the config dictionary to the new root url. If the
569+
root url has changed, all of the urls in the config that correspond to component
570+
file urls are updated to use the new root url.
571+
"""
572+
previous_root = config.get("root", None)
573+
if previous_root is None or previous_root != root:
574+
config["root"] = root
575+
config = processing_utils.add_root_url(config, root, previous_root)
576+
return config

gradio/routes.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
import asyncio
77
import contextlib
8-
import copy
98
import sys
109

1110
if sys.version_info >= (3, 9):
@@ -311,19 +310,18 @@ def login(form_data: OAuth2PasswordRequestForm = Depends()):
311310
def main(request: fastapi.Request, user: str = Depends(get_current_user)):
312311
mimetypes.add_type("application/javascript", ".js")
313312
blocks = app.get_blocks()
314-
root_path = route_utils.get_root_url(
313+
root = route_utils.get_root_url(
315314
request=request, route_path="/", root_path=app.root_path
316315
)
317316
if app.auth is None or user is not None:
318-
config = copy.deepcopy(app.get_blocks().config)
319-
config["root"] = root_path
320-
config = add_root_url(config, root_path)
317+
config = app.get_blocks().config
318+
config = route_utils.update_root_in_config(config, root)
321319
else:
322320
config = {
323321
"auth_required": True,
324322
"auth_message": blocks.auth_message,
325323
"space_id": app.get_blocks().space_id,
326-
"root": root_path,
324+
"root": root,
327325
}
328326

329327
try:
@@ -354,13 +352,12 @@ def api_info():
354352
@app.get("/config/", dependencies=[Depends(login_check)])
355353
@app.get("/config", dependencies=[Depends(login_check)])
356354
def get_config(request: fastapi.Request):
357-
config = copy.deepcopy(app.get_blocks().config)
358-
root_path = route_utils.get_root_url(
355+
config = app.get_blocks().config
356+
root = route_utils.get_root_url(
359357
request=request, route_path="/config", root_path=app.root_path
360358
)
361-
config["root"] = root_path
362-
config = add_root_url(config, root_path)
363-
return config
359+
config = route_utils.update_root_in_config(config, root)
360+
return ORJSONResponse(content=config)
364361

365362
@app.get("/static/{path:path}")
366363
def static_resource(path: str):
@@ -577,7 +574,7 @@ async def predict(
577574
root_path = route_utils.get_root_url(
578575
request=request, route_path=f"/api/{api_name}", root_path=app.root_path
579576
)
580-
output = add_root_url(output, root_path)
577+
output = add_root_url(output, root_path, None)
581578
return output
582579

583580
@app.get("/queue/data", dependencies=[Depends(login_check)])
@@ -634,7 +631,7 @@ async def sse_stream(request: fastapi.Request):
634631
"success": False,
635632
}
636633
if message:
637-
add_root_url(message, root_path)
634+
add_root_url(message, root_path, None)
638635
yield f"data: {json.dumps(message)}\n\n"
639636
if message["msg"] == ServerMessage.process_completed:
640637
blocks._queue.pending_event_ids_session[

test/test_processing_utils.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,3 +332,42 @@ def test_video_conversion_returns_original_video_if_fails(
332332
)
333333
# If the conversion succeeded it'd be .mp4
334334
assert Path(playable_vid).suffix == ".avi"
335+
336+
337+
def test_add_root_url():
338+
data = {
339+
"file": {
340+
"path": "path",
341+
"url": "/file=path",
342+
},
343+
"file2": {
344+
"path": "path2",
345+
"url": "https://www.gradio.app",
346+
},
347+
}
348+
root_url = "http://localhost:7860"
349+
expected = {
350+
"file": {
351+
"path": "path",
352+
"url": f"{root_url}/file=path",
353+
},
354+
"file2": {
355+
"path": "path2",
356+
"url": "https://www.gradio.app",
357+
},
358+
}
359+
assert processing_utils.add_root_url(data, root_url, None) == expected
360+
new_root_url = "https://1234.gradio.live"
361+
new_expected = {
362+
"file": {
363+
"path": "path",
364+
"url": f"{root_url}/file=path",
365+
},
366+
"file2": {
367+
"path": "path2",
368+
"url": "https://www.gradio.app",
369+
},
370+
}
371+
assert (
372+
processing_utils.add_root_url(expected, root_url, new_root_url) == new_expected
373+
)

test/test_routes.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,18 @@ def test_proxy_does_not_leak_hf_token_externally(self):
439439
r = app.build_proxy_request("https://google.com")
440440
assert "authorization" not in dict(r.headers)
441441

442+
def test_can_get_config_that_includes_non_pickle_able_objects(self):
443+
my_dict = {"a": 1, "b": 2, "c": 3}
444+
with Blocks() as demo:
445+
gr.JSON(my_dict.keys())
446+
447+
app, _, _ = demo.launch(prevent_thread_lock=True)
448+
client = TestClient(app)
449+
response = client.get("/")
450+
assert response.is_success
451+
response = client.get("/config/")
452+
assert response.is_success
453+
442454

443455
class TestApp:
444456
def test_create_app(self):

0 commit comments

Comments
 (0)