Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ async def main():
from typing_extensions import TypeVar

import mcp.types as types
from mcp.server.auth.middleware.auth_context import auth_context_var
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser
from mcp.server.experimental.request_context import Experimental
from mcp.server.lowlevel.experimental import ExperimentalHandlers
from mcp.server.lowlevel.func_inspection import create_call_wrapper
Expand Down Expand Up @@ -723,6 +725,7 @@ async def _handle_request(
logger.debug("Dispatching request of type %s", type(req).__name__)

token = None
auth_token = None
try:
# Extract request context and close_sse_stream from message metadata
request_data = None
Expand All @@ -743,6 +746,14 @@ async def _handle_request(
task_metadata = None
if hasattr(req, "params") and req.params is not None:
task_metadata = getattr(req.params, "task", None)
if request_data is not None:
scope = getattr(request_data, "scope", None)
if isinstance(scope, dict):
scope_dict = cast(dict[str, Any], scope)
user = scope_dict.get("user")
if isinstance(user, AuthenticatedUser):
auth_token = auth_context_var.set(user)

token = request_ctx.set(
RequestContext(
message.request_id,
Expand Down Expand Up @@ -775,6 +786,8 @@ async def _handle_request(
response = types.ErrorData(code=0, message=str(err), data=None)
finally:
# Reset the global state after we are done
if auth_token is not None:
auth_context_var.reset(auth_token)
if token is not None: # pragma: no branch
request_ctx.reset(token)

Expand Down
69 changes: 69 additions & 0 deletions tests/server/lowlevel/test_auth_context_from_request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from unittest.mock import AsyncMock, Mock

import pytest
from starlette.requests import Request
from starlette.types import Scope

import mcp.types as types
from mcp.server.auth.middleware.auth_context import auth_context_var, get_access_token
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser
from mcp.server.auth.provider import AccessToken
from mcp.server.lowlevel.server import Server
from mcp.server.session import ServerSession
from mcp.shared.message import ServerMessageMetadata
from mcp.shared.session import RequestResponder


@pytest.mark.anyio
async def test_handle_request_sets_auth_context_from_request() -> None:
server = Server("test-server")

@server.list_tools()
async def handle_list_tools() -> list[types.Tool]:
return [
types.Tool(
name="echo_access_token",
description="Return access token",
inputSchema={"type": "object", "properties": {}},
)
]

@server.call_tool()
async def handle_call_tool(name: str, arguments: dict[str, object]) -> list[types.TextContent]:
assert name == "echo_access_token"
access_token = get_access_token()
token = access_token.token if access_token else ""
return [types.TextContent(type="text", text=token)]

access_token = AccessToken(token="test-token", client_id="client", scopes=["test"])
user = AuthenticatedUser(access_token)
headers: list[tuple[bytes, bytes]] = []
scope: Scope = {
"type": "http",
"method": "POST",
"path": "/mcp",
"headers": headers,
"user": user,
}
request = Request(scope)

message = Mock(spec=RequestResponder)
message.request_id = "req-1"
message.request_meta = None
message.message_metadata = ServerMessageMetadata(request_context=request)
message.respond = AsyncMock()

session = Mock(spec=ServerSession)
session.client_params = None

call_request = types.CallToolRequest(params=types.CallToolRequestParams(name="echo_access_token", arguments={}))

await server._handle_request(message, call_request, session, {}, raise_exceptions=False)

assert auth_context_var.get() is None
assert message.respond.called
response = message.respond.call_args.args[0]
assert isinstance(response.root, types.CallToolResult)
content = response.root.content[0]
assert isinstance(content, types.TextContent)
assert content.text == "test-token"
114 changes: 114 additions & 0 deletions tests/shared/test_streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from httpx_sse import ServerSentEvent
from pydantic import AnyUrl
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.requests import Request
from starlette.routing import Mount

Expand All @@ -32,6 +34,9 @@
streamablehttp_client, # pyright: ignore[reportDeprecated]
)
from mcp.server import Server
from mcp.server.auth.middleware.auth_context import AuthContextMiddleware, get_access_token
from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend
from mcp.server.auth.provider import AccessToken, TokenVerifier
from mcp.server.streamable_http import (
MCP_PROTOCOL_VERSION_HEADER,
MCP_SESSION_ID_HEADER,
Expand Down Expand Up @@ -1520,6 +1525,71 @@ def run_context_aware_server(port: int): # pragma: no cover
server_instance.run()


class AuthTokenServerTest(Server): # pragma: no cover
def __init__(self):
super().__init__("AuthTokenServer")

@self.list_tools()
async def handle_list_tools() -> list[Tool]:
return [
Tool(
name="echo_access_token",
description="Return the current access token",
inputSchema={"type": "object", "properties": {}},
)
]

@self.call_tool()
async def handle_call_tool(name: str, _args: dict[str, Any]) -> list[TextContent]:
assert name == "echo_access_token"
access_token = get_access_token()
assert access_token is not None
return [TextContent(type="text", text=access_token.token)]


def run_auth_token_server(port: int) -> None: # pragma: no cover
"""Run the auth token test server."""
server = AuthTokenServerTest()

class AcceptAllTokenVerifier(TokenVerifier):
async def verify_token(self, token: str) -> AccessToken | None:
return AccessToken(
token=token,
client_id="test-client",
scopes=["test"],
)

token_verifier = AcceptAllTokenVerifier()

session_manager = StreamableHTTPSessionManager(
app=server,
event_store=None,
json_response=False,
)

middleware = [
Middleware(AuthenticationMiddleware, backend=BearerAuthBackend(token_verifier)),
Middleware(AuthContextMiddleware),
]

app = Starlette(
debug=True,
routes=[Mount("/mcp", app=session_manager.handle_request)],
middleware=middleware,
lifespan=lambda app: session_manager.run(),
)

server_instance = uvicorn.Server(
config=uvicorn.Config(
app=app,
host="127.0.0.1",
port=port,
log_level="error",
)
)
server_instance.run()


@pytest.fixture
def context_aware_server(basic_server_port: int) -> Generator[None, None, None]:
"""Start the context-aware server in a separate process."""
Expand All @@ -1537,6 +1607,22 @@ def context_aware_server(basic_server_port: int) -> Generator[None, None, None]:
print("Context-aware server process failed to terminate")


@pytest.fixture
def auth_token_server(basic_server_port: int) -> Generator[None, None, None]:
"""Start the auth token server in a separate process."""
proc = multiprocessing.Process(target=run_auth_token_server, args=(basic_server_port,), daemon=True)
proc.start()

wait_for_server(basic_server_port)

yield

proc.kill()
proc.join(timeout=2)
if proc.is_alive(): # pragma: no cover
print("Auth token server process failed to terminate")


@pytest.mark.anyio
async def test_streamablehttp_request_context_propagation(context_aware_server: None, basic_server_url: str) -> None:
"""Test that request context is properly propagated through StreamableHTTP."""
Expand Down Expand Up @@ -1571,6 +1657,34 @@ async def test_streamablehttp_request_context_propagation(context_aware_server:
assert headers_data.get("x-trace-id") == "trace-123"


@pytest.mark.anyio
async def test_streamablehttp_refreshes_access_token(auth_token_server: None, basic_server_url: str) -> None:
"""Ensure refreshed bearer tokens are used for subsequent requests."""
token_a = "token-a"
token_b = "token-b"

async with create_mcp_http_client(headers={"Authorization": f"Bearer {token_a}"}) as httpx_client:
async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client) as (
read_stream,
write_stream,
_,
):
async with ClientSession(read_stream, write_stream) as session:
result = await session.initialize()
assert isinstance(result, InitializeResult)

tool_result = await session.call_tool("echo_access_token", {})
assert len(tool_result.content) == 1
assert isinstance(tool_result.content[0], TextContent)
assert tool_result.content[0].text == token_a

httpx_client.headers["Authorization"] = f"Bearer {token_b}"
tool_result = await session.call_tool("echo_access_token", {})
assert len(tool_result.content) == 1
assert isinstance(tool_result.content[0], TextContent)
assert tool_result.content[0].text == token_b


@pytest.mark.anyio
async def test_streamablehttp_request_context_isolation(context_aware_server: None, basic_server_url: str) -> None:
"""Test that request contexts are isolated between StreamableHTTP clients."""
Expand Down
Loading