diff --git a/context.go b/context.go index 9c1eda8..baae996 100644 --- a/context.go +++ b/context.go @@ -7,6 +7,7 @@ import ( ) type requestIDKey struct{} +type metadataIDKey struct{} type methodNameKey struct{} // RequestID takes request id from context. @@ -19,6 +20,16 @@ func WithRequestID(c context.Context, id *fastjson.RawMessage) context.Context { return context.WithValue(c, requestIDKey{}, id) } +// GetMetadata takes jsonrpc metadata from context. +func GetMetadata(c context.Context) Metadata { + return c.Value(metadataIDKey{}).(Metadata) +} + +// WithMetadata adds jsonrpc metadata to context. +func WithMetadata(c context.Context, md Metadata) context.Context { + return context.WithValue(c, metadataIDKey{}, md) +} + // MethodName takes method name from context. func MethodName(c context.Context) string { return c.Value(methodNameKey{}).(string) diff --git a/context_test.go b/context_test.go index 3ae869e..0a3653b 100644 --- a/context_test.go +++ b/context_test.go @@ -20,6 +20,18 @@ func TestRequestID(t *testing.T) { require.Equal(t, &id, pick) } +func TestMetadata(t *testing.T) { + + c := context.Background() + md := Metadata{Params: Metadata{}} + c = WithMetadata(c, md) + var pick Metadata + require.NotPanics(t, func() { + pick = GetMetadata(c) + }) + require.Equal(t, md, pick) +} + func TestMethodName(t *testing.T) { c := context.Background() diff --git a/handler.go b/handler.go index adfe778..621a0f1 100644 --- a/handler.go +++ b/handler.go @@ -55,16 +55,17 @@ func (mr *MethodRepository) ServeHTTP(w http.ResponseWriter, r *http.Request) { // InvokeMethod invokes JSON-RPC method. func (mr *MethodRepository) InvokeMethod(c context.Context, r *Request) *Response { - var h Handler + var md Metadata res := NewResponse(r) - h, res.Error = mr.TakeMethod(r) + md, res.Error = mr.TakeMethodMetadata(r) if res.Error != nil { return res } wrappedContext := WithRequestID(c, r.ID) wrappedContext = WithMethodName(wrappedContext, r.Method) - res.Result, res.Error = h.ServeJSONRPC(wrappedContext, r.Params) + wrappedContext = WithMetadata(wrappedContext, md) + res.Result, res.Error = md.Handler.ServeJSONRPC(wrappedContext, r.Params) if res.Error != nil { res.Result = nil } diff --git a/method.go b/method.go index fcc8961..fe0125a 100644 --- a/method.go +++ b/method.go @@ -27,19 +27,29 @@ func NewMethodRepository() *MethodRepository { } } -// TakeMethod takes jsonrpc.Func in MethodRepository. -func (mr *MethodRepository) TakeMethod(r *Request) (Handler, *Error) { +// TakeMethodMetadata takes metadata in MethodRepository for request. +func (mr *MethodRepository) TakeMethodMetadata(r *Request) (Metadata, *Error) { + if r.Method == "" || r.Version != Version { - return nil, ErrInvalidParams() + return Metadata{}, ErrInvalidParams() } mr.m.RLock() md, ok := mr.r[r.Method] mr.m.RUnlock() if !ok { - return nil, ErrMethodNotFound() + return Metadata{}, ErrMethodNotFound() } + return md, nil +} + +// TakeMethod takes jsonrpc.Func in MethodRepository. +func (mr *MethodRepository) TakeMethod(r *Request) (Handler, *Error) { + md, err := mr.TakeMethodMetadata(r) + if err != nil { + return nil, err + } return md.Handler, nil }