|
|
@@ -0,0 +1,376 @@
|
|
|
+package jsonrpc2
|
|
|
+
|
|
|
+import (
|
|
|
+ "bytes"
|
|
|
+ "context"
|
|
|
+ "encoding/json"
|
|
|
+ "fmt"
|
|
|
+ "io"
|
|
|
+ "net/http"
|
|
|
+ "net/http/httptest"
|
|
|
+ "sync"
|
|
|
+ "testing"
|
|
|
+ "time"
|
|
|
+
|
|
|
+ "github.com/sirupsen/logrus"
|
|
|
+ "github.com/stretchr/testify/assert"
|
|
|
+ "github.com/stretchr/testify/require"
|
|
|
+)
|
|
|
+
|
|
|
+// ---------- helpers ----------
|
|
|
+
|
|
|
+func newTestRPCServer(t *testing.T) *RPCServer {
|
|
|
+ t.Helper()
|
|
|
+
|
|
|
+ logger := logrus.New()
|
|
|
+ logger.SetOutput(io.Discard) // silence logs during tests
|
|
|
+
|
|
|
+ srv, err := NewRPCServer("test-rpc", logger)
|
|
|
+ require.NoError(t, err)
|
|
|
+
|
|
|
+ srv.SetMaxBodySize(1024)
|
|
|
+
|
|
|
+ return srv
|
|
|
+}
|
|
|
+
|
|
|
+func postJSON(h http.Handler, body string) *httptest.ResponseRecorder {
|
|
|
+ req := httptest.NewRequest(
|
|
|
+ http.MethodPost,
|
|
|
+ "/rpc",
|
|
|
+ bytes.NewBufferString(body),
|
|
|
+ )
|
|
|
+
|
|
|
+ req.Header.Set("Content-Type", "application/json")
|
|
|
+
|
|
|
+ rec := httptest.NewRecorder()
|
|
|
+ h.ServeHTTP(rec, req)
|
|
|
+
|
|
|
+ return rec
|
|
|
+}
|
|
|
+
|
|
|
+func parseResponse(
|
|
|
+ t *testing.T,
|
|
|
+ rec *httptest.ResponseRecorder,
|
|
|
+) map[string]any {
|
|
|
+ t.Helper()
|
|
|
+
|
|
|
+ var resp map[string]any
|
|
|
+
|
|
|
+ err := json.Unmarshal(rec.Body.Bytes(), &resp)
|
|
|
+ require.NoError(t, err)
|
|
|
+
|
|
|
+ return resp
|
|
|
+}
|
|
|
+
|
|
|
+// ---------- basic ----------
|
|
|
+
|
|
|
+func TestPing(t *testing.T) {
|
|
|
+ srv := newTestRPCServer(t)
|
|
|
+
|
|
|
+ rec := postJSON(
|
|
|
+ srv,
|
|
|
+ `{"jsonrpc":"2.0","method":"ping","id":1}`,
|
|
|
+ )
|
|
|
+
|
|
|
+ assert.Equal(t, http.StatusOK, rec.Code)
|
|
|
+
|
|
|
+ resp := parseResponse(t, rec)
|
|
|
+
|
|
|
+ assert.Equal(t, "2.0", resp["jsonrpc"])
|
|
|
+ assert.Equal(t, "pong", resp["result"])
|
|
|
+ assert.Equal(t, float64(1), resp["id"])
|
|
|
+}
|
|
|
+
|
|
|
+func TestNotification(t *testing.T) {
|
|
|
+ srv := newTestRPCServer(t)
|
|
|
+
|
|
|
+ rec := postJSON(
|
|
|
+ srv,
|
|
|
+ `{"jsonrpc":"2.0","method":"ping"}`,
|
|
|
+ )
|
|
|
+
|
|
|
+ assert.Equal(t, http.StatusOK, rec.Code)
|
|
|
+ assert.Empty(t, rec.Body.String())
|
|
|
+}
|
|
|
+
|
|
|
+func TestOnlyPOST(t *testing.T) {
|
|
|
+ srv := newTestRPCServer(t)
|
|
|
+
|
|
|
+ req := httptest.NewRequest(http.MethodGet, "/rpc", nil)
|
|
|
+ rec := httptest.NewRecorder()
|
|
|
+
|
|
|
+ srv.ServeHTTP(rec, req)
|
|
|
+
|
|
|
+ assert.Equal(t, http.StatusMethodNotAllowed, rec.Code)
|
|
|
+}
|
|
|
+
|
|
|
+func TestMethodNotFound(t *testing.T) {
|
|
|
+ srv := newTestRPCServer(t)
|
|
|
+
|
|
|
+ rec := postJSON(
|
|
|
+ srv,
|
|
|
+ `{"jsonrpc":"2.0","method":"unknown","id":1}`,
|
|
|
+ )
|
|
|
+
|
|
|
+ assert.Equal(t, http.StatusOK, rec.Code)
|
|
|
+
|
|
|
+ resp := parseResponse(t, rec)
|
|
|
+
|
|
|
+ errObj, ok := resp["error"].(map[string]any)
|
|
|
+ require.True(t, ok)
|
|
|
+
|
|
|
+ assert.Equal(t, float64(-32601), errObj["code"])
|
|
|
+ assert.Contains(t, errObj["message"], "Method not found")
|
|
|
+}
|
|
|
+
|
|
|
+func TestInvalidJSON(t *testing.T) {
|
|
|
+ srv := newTestRPCServer(t)
|
|
|
+
|
|
|
+ rec := postJSON(srv, `{invalid json}`)
|
|
|
+
|
|
|
+ assert.Equal(t, http.StatusOK, rec.Code)
|
|
|
+
|
|
|
+ resp := parseResponse(t, rec)
|
|
|
+
|
|
|
+ errObj, ok := resp["error"].(map[string]any)
|
|
|
+ require.True(t, ok)
|
|
|
+
|
|
|
+ assert.Equal(t, float64(-32700), errObj["code"])
|
|
|
+}
|
|
|
+
|
|
|
+func TestBatchRejected(t *testing.T) {
|
|
|
+ srv := newTestRPCServer(t)
|
|
|
+
|
|
|
+ rec := postJSON(
|
|
|
+ srv,
|
|
|
+ `[{"jsonrpc":"2.0","method":"ping","id":1}]`,
|
|
|
+ )
|
|
|
+
|
|
|
+ assert.Equal(t, http.StatusOK, rec.Code)
|
|
|
+
|
|
|
+ resp := parseResponse(t, rec)
|
|
|
+
|
|
|
+ errObj, ok := resp["error"].(map[string]any)
|
|
|
+ require.True(t, ok)
|
|
|
+
|
|
|
+ assert.Equal(t, float64(-32600), errObj["code"])
|
|
|
+}
|
|
|
+
|
|
|
+func TestRequestBodyTooLarge(t *testing.T) {
|
|
|
+ srv := newTestRPCServer(t)
|
|
|
+
|
|
|
+ srv.SetMaxBodySize(10)
|
|
|
+
|
|
|
+ rec := postJSON(
|
|
|
+ srv,
|
|
|
+ `{"jsonrpc":"2.0","method":"ping","id":1}`,
|
|
|
+ )
|
|
|
+
|
|
|
+ assert.Equal(t, http.StatusBadRequest, rec.Code)
|
|
|
+}
|
|
|
+
|
|
|
+// ---------- handler ----------
|
|
|
+
|
|
|
+func TestHandlerPanic(t *testing.T) {
|
|
|
+ srv := newTestRPCServer(t)
|
|
|
+
|
|
|
+ err := srv.RegisterMethods(MethodMap{
|
|
|
+ "panic.method": func(
|
|
|
+ ctx context.Context,
|
|
|
+ req *Request,
|
|
|
+ ) *Response {
|
|
|
+ panic("boom")
|
|
|
+ },
|
|
|
+ })
|
|
|
+
|
|
|
+ require.NoError(t, err)
|
|
|
+
|
|
|
+ rec := postJSON(
|
|
|
+ srv,
|
|
|
+ `{"jsonrpc":"2.0","method":"panic.method","id":1}`,
|
|
|
+ )
|
|
|
+
|
|
|
+ assert.Equal(t, http.StatusOK, rec.Code)
|
|
|
+
|
|
|
+ resp := parseResponse(t, rec)
|
|
|
+
|
|
|
+ errObj, ok := resp["error"].(map[string]any)
|
|
|
+ require.True(t, ok)
|
|
|
+
|
|
|
+ assert.Equal(t, float64(-32603), errObj["code"])
|
|
|
+}
|
|
|
+
|
|
|
+func TestNilHandlerResponse(t *testing.T) {
|
|
|
+ srv := newTestRPCServer(t)
|
|
|
+
|
|
|
+ err := srv.RegisterMethods(MethodMap{
|
|
|
+ "nil.response": func(
|
|
|
+ ctx context.Context,
|
|
|
+ req *Request,
|
|
|
+ ) *Response {
|
|
|
+ return nil
|
|
|
+ },
|
|
|
+ })
|
|
|
+
|
|
|
+ require.NoError(t, err)
|
|
|
+
|
|
|
+ rec := postJSON(
|
|
|
+ srv,
|
|
|
+ `{"jsonrpc":"2.0","method":"nil.response","id":1}`,
|
|
|
+ )
|
|
|
+
|
|
|
+ assert.Equal(t, http.StatusOK, rec.Code)
|
|
|
+
|
|
|
+ resp := parseResponse(t, rec)
|
|
|
+
|
|
|
+ assert.Nil(t, resp["result"])
|
|
|
+}
|
|
|
+
|
|
|
+// ---------- register ----------
|
|
|
+
|
|
|
+func TestRegisterMethodsAtomic(t *testing.T) {
|
|
|
+ srv := newTestRPCServer(t)
|
|
|
+
|
|
|
+ err := srv.RegisterMethods(MethodMap{
|
|
|
+ "a": func(
|
|
|
+ ctx context.Context,
|
|
|
+ req *Request,
|
|
|
+ ) *Response {
|
|
|
+ return BuildResponse(req, "ok", nil)
|
|
|
+ },
|
|
|
+ })
|
|
|
+
|
|
|
+ require.NoError(t, err)
|
|
|
+
|
|
|
+ err = srv.RegisterMethods(MethodMap{
|
|
|
+ "b": nil, // invalid
|
|
|
+ })
|
|
|
+
|
|
|
+ assert.Error(t, err)
|
|
|
+ assert.NotContains(t, srv.ListMethods(), "b")
|
|
|
+}
|
|
|
+
|
|
|
+func TestListMethodsSorted(t *testing.T) {
|
|
|
+ srv := newTestRPCServer(t)
|
|
|
+
|
|
|
+ err := srv.RegisterMethods(MethodMap{
|
|
|
+ "z": func(
|
|
|
+ ctx context.Context,
|
|
|
+ req *Request,
|
|
|
+ ) *Response {
|
|
|
+ return nil
|
|
|
+ },
|
|
|
+ "a": func(
|
|
|
+ ctx context.Context,
|
|
|
+ req *Request,
|
|
|
+ ) *Response {
|
|
|
+ return nil
|
|
|
+ },
|
|
|
+ })
|
|
|
+
|
|
|
+ require.NoError(t, err)
|
|
|
+
|
|
|
+ methods := srv.ListMethods()
|
|
|
+
|
|
|
+ assert.Equal(
|
|
|
+ t,
|
|
|
+ []string{"a", "ping", "z"},
|
|
|
+ methods,
|
|
|
+ )
|
|
|
+}
|
|
|
+
|
|
|
+// ---------- concurrency ----------
|
|
|
+
|
|
|
+func TestConcurrentRequests(t *testing.T) {
|
|
|
+ srv := newTestRPCServer(t)
|
|
|
+
|
|
|
+ err := srv.RegisterMethods(MethodMap{
|
|
|
+ "echo": func(
|
|
|
+ ctx context.Context,
|
|
|
+ req *Request,
|
|
|
+ ) *Response {
|
|
|
+ time.Sleep(10 * time.Millisecond)
|
|
|
+ return BuildResponse(req, "ok", nil)
|
|
|
+ },
|
|
|
+ })
|
|
|
+
|
|
|
+ require.NoError(t, err)
|
|
|
+
|
|
|
+ const workers = 100
|
|
|
+
|
|
|
+ var wg sync.WaitGroup
|
|
|
+ errCh := make(chan error, workers)
|
|
|
+
|
|
|
+ for i := 0; i < workers; i++ {
|
|
|
+ wg.Add(1)
|
|
|
+
|
|
|
+ go func(id int) {
|
|
|
+ defer wg.Done()
|
|
|
+
|
|
|
+ body := fmt.Sprintf(
|
|
|
+ `{"jsonrpc":"2.0","method":"echo","id":%d}`,
|
|
|
+ id,
|
|
|
+ )
|
|
|
+
|
|
|
+ rec := postJSON(srv, body)
|
|
|
+
|
|
|
+ if rec.Code != http.StatusOK {
|
|
|
+ errCh <- fmt.Errorf(
|
|
|
+ "id=%d status=%d",
|
|
|
+ id,
|
|
|
+ rec.Code,
|
|
|
+ )
|
|
|
+ }
|
|
|
+ }(i)
|
|
|
+ }
|
|
|
+
|
|
|
+ wg.Wait()
|
|
|
+ close(errCh)
|
|
|
+
|
|
|
+ for err := range errCh {
|
|
|
+ t.Error(err)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// ---------- lifecycle ----------
|
|
|
+
|
|
|
+func TestStopCancelsHandlers(t *testing.T) {
|
|
|
+ srv := newTestRPCServer(t)
|
|
|
+
|
|
|
+ done := make(chan struct{})
|
|
|
+
|
|
|
+ err := srv.RegisterMethods(MethodMap{
|
|
|
+ "block": func(
|
|
|
+ ctx context.Context,
|
|
|
+ req *Request,
|
|
|
+ ) *Response {
|
|
|
+ <-ctx.Done()
|
|
|
+ close(done)
|
|
|
+
|
|
|
+ return BuildResponse(
|
|
|
+ req,
|
|
|
+ "canceled",
|
|
|
+ nil,
|
|
|
+ )
|
|
|
+ },
|
|
|
+ })
|
|
|
+
|
|
|
+ require.NoError(t, err)
|
|
|
+
|
|
|
+ go func() {
|
|
|
+ postJSON(
|
|
|
+ srv,
|
|
|
+ `{"jsonrpc":"2.0","method":"block","id":1}`,
|
|
|
+ )
|
|
|
+ }()
|
|
|
+
|
|
|
+ time.Sleep(50 * time.Millisecond)
|
|
|
+
|
|
|
+ srv.Stop()
|
|
|
+
|
|
|
+ select {
|
|
|
+ case <-done:
|
|
|
+ case <-time.After(time.Second):
|
|
|
+ t.Fatal("handler was not canceled after Stop()")
|
|
|
+ }
|
|
|
+}
|