Skip to content

refactor(client): improve validation and remove server methods #14

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Feb 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

package io.modelcontextprotocol;

import java.util.concurrent.atomic.AtomicInteger;
import java.util.ArrayList;
import java.util.List;
import java.util.function.BiConsumer;
import java.util.function.Function;

import com.fasterxml.jackson.core.type.TypeReference;
Expand All @@ -14,47 +16,53 @@
import io.modelcontextprotocol.spec.ServerMcpTransport;
import io.modelcontextprotocol.spec.McpSchema.JSONRPCNotification;
import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Sinks;
import reactor.core.scheduler.Schedulers;

@SuppressWarnings("unused")
/**
* A mock implementation of the {@link ClientMcpTransport} and {@link ServerMcpTransport}
* interfaces.
*/
public class MockMcpTransport implements ClientMcpTransport, ServerMcpTransport {

private final AtomicInteger inboundMessageCount = new AtomicInteger(0);
private final Sinks.Many<McpSchema.JSONRPCMessage> inbound = Sinks.many().unicast().onBackpressureBuffer();

private final List<McpSchema.JSONRPCMessage> sent = new ArrayList<>();

private final Sinks.Many<McpSchema.JSONRPCMessage> outgoing = Sinks.many().multicast().onBackpressureBuffer();
private final BiConsumer<MockMcpTransport, McpSchema.JSONRPCMessage> interceptor;

private final Sinks.Many<McpSchema.JSONRPCMessage> inbound = Sinks.many().unicast().onBackpressureBuffer();
public MockMcpTransport() {
this((t, msg) -> {
});
}

private final Flux<McpSchema.JSONRPCMessage> outboundView = outgoing.asFlux().cache(1);
public MockMcpTransport(BiConsumer<MockMcpTransport, McpSchema.JSONRPCMessage> interceptor) {
this.interceptor = interceptor;
}

public void simulateIncomingMessage(McpSchema.JSONRPCMessage message) {
if (inbound.tryEmitNext(message).isFailure()) {
throw new RuntimeException("Failed to emit message " + message);
throw new RuntimeException("Failed to process incoming message " + message);
}
inboundMessageCount.incrementAndGet();
}

@Override
public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
if (outgoing.tryEmitNext(message).isFailure()) {
return Mono.error(new RuntimeException("Can't emit outgoing message " + message));
}
sent.add(message);
interceptor.accept(this, message);
return Mono.empty();
}

public McpSchema.JSONRPCRequest getLastSentMessageAsRequest() {
return (JSONRPCRequest) outboundView.blockFirst();
return (JSONRPCRequest) getLastSentMessage();
}

public McpSchema.JSONRPCNotification getLastSentMessageAsNotifiation() {
return (JSONRPCNotification) outboundView.blockFirst();
public McpSchema.JSONRPCNotification getLastSentMessageAsNotification() {
return (JSONRPCNotification) getLastSentMessage();
}

public McpSchema.JSONRPCMessage getLastSentMessage() {
return outboundView.blockFirst();
return !sent.isEmpty() ? sent.get(sent.size() - 1) : null;
}

private volatile boolean connected = false;
Expand All @@ -66,7 +74,6 @@ public Mono<Void> connect(Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchem
}
connected = true;
return inbound.asFlux()
.publishOn(Schedulers.boundedElastic())
.flatMap(message -> Mono.just(message).transform(handler))
.doFinally(signal -> connected = false)
.then();
Expand All @@ -76,8 +83,8 @@ public Mono<Void> connect(Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchem
public Mono<Void> closeGracefully() {
return Mono.defer(() -> {
connected = false;
outgoing.tryEmitComplete();
inbound.tryEmitComplete();
// Wait for all subscribers to complete
return Mono.empty();
});
}
Expand All @@ -87,4 +94,4 @@ public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) {
return new ObjectMapper().convertValue(data, typeRef);
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -274,9 +274,6 @@ void testNotificationHandlers() {

assertThatCode(() -> {
client.initialize().block();
// Trigger notifications
client.sendResourcesListChanged().block();
client.promptListChangedNotification().block();
client.closeGracefully().block();
}).doesNotThrowAnyException();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,6 @@ void testNotificationHandlers() {

assertThatCode(() -> {
client.initialize();
// Trigger notifications
client.sendResourcesListChanged();
client.promptListChangedNotification();
client.close();
}).doesNotThrowAnyException();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,14 @@ public McpSchema.Implementation getServerInfo() {
return this.serverInfo;
}

/**
* Check if the client-server connection is initialized.
* @return true if the client-server connection is initialized
*/
public boolean isInitialized() {
return this.serverCapabilities != null;
}

/**
* Get the client capabilities that define the supported features and functionality.
* @return The client capabilities
Expand Down Expand Up @@ -456,6 +464,12 @@ private RequestHandler<CreateMessageResult> samplingCreateMessageHandler() {
* (false/absent)
*/
public Mono<McpSchema.CallToolResult> callTool(McpSchema.CallToolRequest callToolRequest) {
if (!this.isInitialized()) {
return Mono.error(new McpError("Client must be initialized before calling tools"));
}
if (this.serverCapabilities.tools() == null) {
return Mono.error(new McpError("Server does not provide tools capability"));
}
return this.mcpSession.sendRequest(McpSchema.METHOD_TOOLS_CALL, callToolRequest, CALL_TOOL_RESULT_TYPE_REF);
}

Expand All @@ -477,6 +491,12 @@ public Mono<McpSchema.ListToolsResult> listTools() {
* Optional cursor for pagination if more tools are available
*/
public Mono<McpSchema.ListToolsResult> listTools(String cursor) {
if (!this.isInitialized()) {
return Mono.error(new McpError("Client must be initialized before listing tools"));
}
if (this.serverCapabilities.tools() == null) {
return Mono.error(new McpError("Server does not provide tools capability"));
}
return this.mcpSession.sendRequest(McpSchema.METHOD_TOOLS_LIST, new McpSchema.PaginatedRequest(cursor),
LIST_TOOLS_RESULT_TYPE_REF);
}
Expand Down Expand Up @@ -532,6 +552,12 @@ public Mono<McpSchema.ListResourcesResult> listResources() {
* @return A Mono that completes with the list of resources result
*/
public Mono<McpSchema.ListResourcesResult> listResources(String cursor) {
if (!this.isInitialized()) {
return Mono.error(new McpError("Client must be initialized before listing resources"));
}
if (this.serverCapabilities.resources() == null) {
return Mono.error(new McpError("Server does not provide the resources capability"));
}
return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_LIST, new McpSchema.PaginatedRequest(cursor),
LIST_RESOURCES_RESULT_TYPE_REF);
}
Expand All @@ -551,6 +577,12 @@ public Mono<McpSchema.ReadResourceResult> readResource(McpSchema.Resource resour
* @return A Mono that completes with the resource content
*/
public Mono<McpSchema.ReadResourceResult> readResource(McpSchema.ReadResourceRequest readResourceRequest) {
if (!this.isInitialized()) {
return Mono.error(new McpError("Client must be initialized before reading resources"));
}
if (this.serverCapabilities.resources() == null) {
return Mono.error(new McpError("Server does not provide the resources capability"));
}
return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_READ, readResourceRequest,
READ_RESOURCE_RESULT_TYPE_REF);
}
Expand All @@ -575,19 +607,16 @@ public Mono<McpSchema.ListResourceTemplatesResult> listResourceTemplates() {
* @return A Mono that completes with the list of resource templates result
*/
public Mono<McpSchema.ListResourceTemplatesResult> listResourceTemplates(String cursor) {
if (!this.isInitialized()) {
return Mono.error(new McpError("Client must be initialized before listing resource templates"));
}
if (this.serverCapabilities.resources() == null) {
return Mono.error(new McpError("Server does not provide the resources capability"));
}
return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST,
new McpSchema.PaginatedRequest(cursor), LIST_RESOURCE_TEMPLATES_RESULT_TYPE_REF);
}

/**
* List Changed Notification. When the list of available resources changes, servers
* that declared the listChanged capability SHOULD send a notification.
* @return A Mono that completes when the notification is sent
*/
public Mono<Void> sendResourcesListChanged() {
return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED);
}

/**
* Subscriptions. The protocol supports optional subscriptions to resource changes.
* Clients can subscribe to specific resources and receive notifications when they
Expand Down Expand Up @@ -660,16 +689,6 @@ public Mono<GetPromptResult> getPrompt(GetPromptRequest getPromptRequest) {
return this.mcpSession.sendRequest(McpSchema.METHOD_PROMPT_GET, getPromptRequest, GET_PROMPT_RESULT_TYPE_REF);
}

/**
* (Server) An optional notification from the server to the client, informing it that
* the list of prompts it offers has changed. This may be issued by servers without
* any previous subscription from the client.
* @return A Mono that completes when the notification is sent
*/
public Mono<Void> promptListChangedNotification() {
return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED);
}

private NotificationHandler asyncPromptsChangeNotificationHandler(
List<Function<List<McpSchema.Prompt>, Mono<Void>>> promptsChangeConsumers) {
return params -> listPrompts().flatMap(listPromptsResult -> Flux.fromIterable(promptsChangeConsumers)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,14 +284,6 @@ public McpSchema.ListResourceTemplatesResult listResourceTemplates() {
return this.delegate.listResourceTemplates().block();
}

/**
* List Changed Notification. When the list of available resources changes, servers
* that declared the listChanged capability SHOULD send a notification:
*/
public void sendResourcesListChanged() {
this.delegate.sendResourcesListChanged().block();
}

/**
* Subscriptions. The protocol supports optional subscriptions to resource changes.
* Clients can subscribe to specific resources and receive notifications when they
Expand Down Expand Up @@ -329,15 +321,6 @@ public GetPromptResult getPrompt(GetPromptRequest getPromptRequest) {
return this.delegate.getPrompt(getPromptRequest).block();
}

/**
* (Server) An optional notification from the server to the client, informing it that
* the list of prompts it offers has changed. This may be issued by servers without
* any previous subscription from the client.
*/
public void promptListChangedNotification() {
this.delegate.promptListChangedNotification().block();
}

/**
* Client can set the minimum logging level it wants to receive from the server.
* @param loggingLevel the min logging level
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,13 @@ private DefaultMcpSession.RequestHandler<McpSchema.InitializeResult> asyncInitia
initializeRequest.protocolVersion(), initializeRequest.capabilities(),
initializeRequest.clientInfo());

// The server MUST respond with the highest protocol version it supports if
// it does not support the requested (e.g. Client) version.
String serverProtocolVersion = this.protocolVersions.get(this.protocolVersions.size() - 1);

if (this.protocolVersions.contains(initializeRequest.protocolVersion())) {
// If the server supports the requested protocol version, it MUST respond
// with the same version.
serverProtocolVersion = initializeRequest.protocolVersion();
}
else {
Expand Down
39 changes: 21 additions & 18 deletions mcp/src/test/java/io/modelcontextprotocol/MockMcpTransport.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

package io.modelcontextprotocol;

import java.util.concurrent.atomic.AtomicInteger;
import java.util.ArrayList;
import java.util.List;
import java.util.function.BiConsumer;
import java.util.function.Function;

import com.fasterxml.jackson.core.type.TypeReference;
Expand All @@ -14,50 +16,53 @@
import io.modelcontextprotocol.spec.ServerMcpTransport;
import io.modelcontextprotocol.spec.McpSchema.JSONRPCNotification;
import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Sinks;
import reactor.core.scheduler.Schedulers;

/**
* A mock implementation of the {@link ClientMcpTransport} and {@link ServerMcpTransport}
* interfaces.
*/
public class MockMcpTransport implements ClientMcpTransport, ServerMcpTransport {

private final AtomicInteger inboundMessageCount = new AtomicInteger(0);
private final Sinks.Many<McpSchema.JSONRPCMessage> inbound = Sinks.many().unicast().onBackpressureBuffer();

private final Sinks.Many<McpSchema.JSONRPCMessage> outgoing = Sinks.many().multicast().onBackpressureBuffer();
private final List<McpSchema.JSONRPCMessage> sent = new ArrayList<>();

private final Sinks.Many<McpSchema.JSONRPCMessage> inbound = Sinks.many().unicast().onBackpressureBuffer();
private final BiConsumer<MockMcpTransport, McpSchema.JSONRPCMessage> interceptor;

private final Flux<McpSchema.JSONRPCMessage> outboundView = outgoing.asFlux().cache(1);
public MockMcpTransport() {
this((t, msg) -> {
});
}

public MockMcpTransport(BiConsumer<MockMcpTransport, McpSchema.JSONRPCMessage> interceptor) {
this.interceptor = interceptor;
}

public void simulateIncomingMessage(McpSchema.JSONRPCMessage message) {
if (inbound.tryEmitNext(message).isFailure()) {
throw new RuntimeException("Failed to emit message " + message);
throw new RuntimeException("Failed to process incoming message " + message);
}
inboundMessageCount.incrementAndGet();
}

@Override
public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
if (outgoing.tryEmitNext(message).isFailure()) {
return Mono.error(new RuntimeException("Can't emit outgoing message " + message));
}
sent.add(message);
interceptor.accept(this, message);
return Mono.empty();
}

public McpSchema.JSONRPCRequest getLastSentMessageAsRequest() {
return (JSONRPCRequest) outboundView.blockFirst();
return (JSONRPCRequest) getLastSentMessage();
}

public McpSchema.JSONRPCNotification getLastSentMessageAsNotifiation() {
return (JSONRPCNotification) outboundView.blockFirst();
public McpSchema.JSONRPCNotification getLastSentMessageAsNotification() {
return (JSONRPCNotification) getLastSentMessage();
}

public McpSchema.JSONRPCMessage getLastSentMessage() {
return outboundView.blockFirst();
return !sent.isEmpty() ? sent.get(sent.size() - 1) : null;
}

private volatile boolean connected = false;
Expand All @@ -69,7 +74,6 @@ public Mono<Void> connect(Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchem
}
connected = true;
return inbound.asFlux()
.publishOn(Schedulers.boundedElastic())
.flatMap(message -> Mono.just(message).transform(handler))
.doFinally(signal -> connected = false)
.then();
Expand All @@ -79,7 +83,6 @@ public Mono<Void> connect(Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchem
public Mono<Void> closeGracefully() {
return Mono.defer(() -> {
connected = false;
outgoing.tryEmitComplete();
inbound.tryEmitComplete();
// Wait for all subscribers to complete
return Mono.empty();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -275,9 +275,6 @@ void testNotificationHandlers() {

assertThatCode(() -> {
client.initialize().block();
// Trigger notifications
client.sendResourcesListChanged().block();
client.promptListChangedNotification().block();
client.closeGracefully().block();
}).doesNotThrowAnyException();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,9 +259,6 @@ void testNotificationHandlers() {

assertThatCode(() -> {
client.initialize();
// Trigger notifications
client.sendResourcesListChanged();
client.promptListChangedNotification();
client.close();
}).doesNotThrowAnyException();
}
Expand Down
Loading