diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index 62264d9a..93cb10d0 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -2,6 +2,7 @@ import java.io.IOException; import java.util.Map; +import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import com.fasterxml.jackson.core.type.TypeReference; @@ -261,8 +262,8 @@ private Mono handleSseConnection(ServerRequest request) { .body(Flux.>create(sink -> { WebFluxMcpSessionTransport sessionTransport = new WebFluxMcpSessionTransport(sink); - McpServerSession session = sessionFactory.create(sessionTransport); - String sessionId = session.getId(); + String sessionId = sessionFactory.generateId(); + McpServerSession session = sessionFactory.create(sessionId, sessionTransport); logger.debug("Created new SSE connection for session: {}", sessionId); sessions.put(sessionId, session); diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java index fc86cfaa..879129f3 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java @@ -247,7 +247,7 @@ private ServerResponse handleSseConnection(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); } - String sessionId = UUID.randomUUID().toString(); + String sessionId = sessionFactory.generateId(); logger.debug("Creating new SSE connection for session: {}", sessionId); // Send initial endpoint event @@ -263,7 +263,7 @@ private ServerResponse handleSseConnection(ServerRequest request) { }); WebMvcMcpSessionTransport sessionTransport = new WebMvcMcpSessionTransport(sessionId, sseBuilder); - McpServerSession session = sessionFactory.create(sessionTransport); + McpServerSession session = sessionFactory.create(sessionId, sessionTransport); this.sessions.put(sessionId, session); try { diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index 1efa13de..3f9bef29 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -183,9 +183,8 @@ public class McpAsyncServer { notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); - mcpTransportProvider.setSessionFactory( - transport -> new McpServerSession(UUID.randomUUID().toString(), requestTimeout, transport, - this::asyncInitializeRequestHandler, Mono::empty, requestHandlers, notificationHandlers)); + mcpTransportProvider.setSessionFactory((id, transport) -> new McpServerSession(id, requestTimeout, transport, + this::asyncInitializeRequestHandler, Mono::empty, requestHandlers, notificationHandlers)); } // --------------------------------------- diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java index afdbff47..f8ea97a4 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java @@ -208,7 +208,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) response.setHeader("Connection", "keep-alive"); response.setHeader("Access-Control-Allow-Origin", "*"); - String sessionId = UUID.randomUUID().toString(); + String sessionId = sessionFactory.generateId(); AsyncContext asyncContext = request.startAsync(); asyncContext.setTimeout(0); @@ -219,7 +219,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) writer); // Create a new session using the session factory - McpServerSession session = sessionFactory.create(sessionTransport); + McpServerSession session = sessionFactory.create(sessionId, sessionTransport); this.sessions.put(sessionId, session); // Send initial endpoint event diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java index 819da977..4c5015fb 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java @@ -12,6 +12,7 @@ import java.io.Reader; import java.nio.charset.StandardCharsets; import java.util.Map; +import java.util.UUID; import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; @@ -94,7 +95,8 @@ public StdioServerTransportProvider(ObjectMapper objectMapper, InputStream input public void setSessionFactory(McpServerSession.Factory sessionFactory) { // Create a single session for the stdio connection var transport = new StdioMcpSessionTransport(); - this.session = sessionFactory.create(transport); + String sessionId = sessionFactory.generateId(); + this.session = sessionFactory.create(sessionId, transport); transport.initProcessing(); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index 86906d85..e9d043a6 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -2,6 +2,7 @@ import java.time.Duration; import java.util.Map; +import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; @@ -343,10 +344,19 @@ public interface Factory { /** * Creates a new 1:1 representation of the client-server interaction. + * @param sessionId the id of the session. * @param sessionTransport the transport to use for communication with the client. * @return a new server session. */ - McpServerSession create(McpServerTransport sessionTransport); + McpServerSession create(String sessionId, McpServerTransport sessionTransport); + + /** + * Generates a unique session id. + * @return a unique session id. + */ + default String generateId() { + return UUID.randomUUID().toString(); + } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java b/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java index 20a8c0cf..5369118a 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java +++ b/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java @@ -16,6 +16,7 @@ package io.modelcontextprotocol; import java.util.Map; +import java.util.UUID; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpServerSession; @@ -43,7 +44,8 @@ public MockMcpServerTransport getTransport() { @Override public void setSessionFactory(Factory sessionFactory) { - session = sessionFactory.create(transport); + String sessionId = sessionFactory.generateId(); + session = sessionFactory.create(sessionId, transport); } @Override diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java index 14987b5a..5d33fed7 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java @@ -10,6 +10,7 @@ import java.io.PrintStream; import java.nio.charset.StandardCharsets; import java.util.Map; +import java.util.UUID; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; @@ -71,7 +72,7 @@ void setUp() { sessionFactory = mock(McpServerSession.Factory.class); // Configure mock behavior - when(sessionFactory.create(any(McpServerTransport.class))).thenReturn(mockSession); + when(sessionFactory.create(any(), any(McpServerTransport.class))).thenReturn(mockSession); when(mockSession.closeGracefully()).thenReturn(Mono.empty()); when(mockSession.sendNotification(any(), any())).thenReturn(Mono.empty()); @@ -110,7 +111,7 @@ void shouldHandleIncomingMessages() throws Exception { AtomicReference capturedMessage = new AtomicReference<>(); CountDownLatch messageLatch = new CountDownLatch(1); - McpServerSession.Factory realSessionFactory = transport -> { + McpServerSession.Factory realSessionFactory = (id, transport) -> { McpServerSession session = mock(McpServerSession.class); when(session.handle(any())).thenAnswer(invocation -> { capturedMessage.set(invocation.getArgument(0));