package jsonrpc2 import ( "bytes" "context" "encoding/json" "fmt" "io" "net/http" "net/http/httptest" "sync" "testing" "time" "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) // ---------- helpers ---------- func newTestRPCServer(t *testing.T) *RPCServer { t.Helper() logger := logrus.New() logger.SetOutput(io.Discard) // silence logs during tests srv, err := NewRPCServer("test-rpc", logger) require.NoError(t, err) srv.SetMaxBodySize(1024) return srv } func postJSON(h http.Handler, body string) *httptest.ResponseRecorder { req := httptest.NewRequest( http.MethodPost, "/rpc", bytes.NewBufferString(body), ) req.Header.Set("Content-Type", "application/json") rec := httptest.NewRecorder() h.ServeHTTP(rec, req) return rec } func parseResponse( t *testing.T, rec *httptest.ResponseRecorder, ) map[string]any { t.Helper() var resp map[string]any err := json.Unmarshal(rec.Body.Bytes(), &resp) require.NoError(t, err) return resp } // ---------- basic ---------- func TestPing(t *testing.T) { srv := newTestRPCServer(t) rec := postJSON( srv, `{"jsonrpc":"2.0","method":"ping","id":1}`, ) assert.Equal(t, http.StatusOK, rec.Code) resp := parseResponse(t, rec) assert.Equal(t, "2.0", resp["jsonrpc"]) assert.Equal(t, "pong", resp["result"]) assert.Equal(t, float64(1), resp["id"]) } func TestNotification(t *testing.T) { srv := newTestRPCServer(t) rec := postJSON( srv, `{"jsonrpc":"2.0","method":"ping"}`, ) assert.Equal(t, http.StatusOK, rec.Code) assert.Empty(t, rec.Body.String()) } func TestOnlyPOST(t *testing.T) { srv := newTestRPCServer(t) req := httptest.NewRequest(http.MethodGet, "/rpc", nil) rec := httptest.NewRecorder() srv.ServeHTTP(rec, req) assert.Equal(t, http.StatusMethodNotAllowed, rec.Code) } func TestMethodNotFound(t *testing.T) { srv := newTestRPCServer(t) rec := postJSON( srv, `{"jsonrpc":"2.0","method":"unknown","id":1}`, ) assert.Equal(t, http.StatusOK, rec.Code) resp := parseResponse(t, rec) errObj, ok := resp["error"].(map[string]any) require.True(t, ok) assert.Equal(t, float64(-32601), errObj["code"]) assert.Contains(t, errObj["message"], "Method not found") } func TestInvalidJSON(t *testing.T) { srv := newTestRPCServer(t) rec := postJSON(srv, `{invalid json}`) assert.Equal(t, http.StatusOK, rec.Code) resp := parseResponse(t, rec) errObj, ok := resp["error"].(map[string]any) require.True(t, ok) assert.Equal(t, float64(-32700), errObj["code"]) } func TestBatchRejected(t *testing.T) { srv := newTestRPCServer(t) rec := postJSON( srv, `[{"jsonrpc":"2.0","method":"ping","id":1}]`, ) assert.Equal(t, http.StatusOK, rec.Code) resp := parseResponse(t, rec) errObj, ok := resp["error"].(map[string]any) require.True(t, ok) assert.Equal(t, float64(-32600), errObj["code"]) } func TestRequestBodyTooLarge(t *testing.T) { srv := newTestRPCServer(t) srv.SetMaxBodySize(10) rec := postJSON( srv, `{"jsonrpc":"2.0","method":"ping","id":1}`, ) assert.Equal(t, http.StatusBadRequest, rec.Code) } // ---------- handler ---------- func TestHandlerPanic(t *testing.T) { srv := newTestRPCServer(t) err := srv.RegisterMethods(MethodMap{ "panic.method": func( ctx context.Context, req *Request, ) *Response { panic("boom") }, }) require.NoError(t, err) rec := postJSON( srv, `{"jsonrpc":"2.0","method":"panic.method","id":1}`, ) assert.Equal(t, http.StatusOK, rec.Code) resp := parseResponse(t, rec) errObj, ok := resp["error"].(map[string]any) require.True(t, ok) assert.Equal(t, float64(-32603), errObj["code"]) } func TestNilHandlerResponse(t *testing.T) { srv := newTestRPCServer(t) err := srv.RegisterMethods(MethodMap{ "nil.response": func( ctx context.Context, req *Request, ) *Response { return nil }, }) require.NoError(t, err) rec := postJSON( srv, `{"jsonrpc":"2.0","method":"nil.response","id":1}`, ) assert.Equal(t, http.StatusOK, rec.Code) resp := parseResponse(t, rec) assert.Nil(t, resp["result"]) } // ---------- register ---------- func TestRegisterMethodsAtomic(t *testing.T) { srv := newTestRPCServer(t) err := srv.RegisterMethods(MethodMap{ "a": func( ctx context.Context, req *Request, ) *Response { return BuildResponse(req, "ok", nil) }, }) require.NoError(t, err) err = srv.RegisterMethods(MethodMap{ "b": nil, // invalid }) assert.Error(t, err) assert.NotContains(t, srv.ListMethods(), "b") } func TestListMethodsSorted(t *testing.T) { srv := newTestRPCServer(t) err := srv.RegisterMethods(MethodMap{ "z": func( ctx context.Context, req *Request, ) *Response { return nil }, "a": func( ctx context.Context, req *Request, ) *Response { return nil }, }) require.NoError(t, err) methods := srv.ListMethods() assert.Equal( t, []string{"a", "ping", "z"}, methods, ) } // ---------- concurrency ---------- func TestConcurrentRequests(t *testing.T) { srv := newTestRPCServer(t) err := srv.RegisterMethods(MethodMap{ "echo": func( ctx context.Context, req *Request, ) *Response { time.Sleep(10 * time.Millisecond) return BuildResponse(req, "ok", nil) }, }) require.NoError(t, err) const workers = 100 var wg sync.WaitGroup errCh := make(chan error, workers) for i := 0; i < workers; i++ { wg.Add(1) go func(id int) { defer wg.Done() body := fmt.Sprintf( `{"jsonrpc":"2.0","method":"echo","id":%d}`, id, ) rec := postJSON(srv, body) if rec.Code != http.StatusOK { errCh <- fmt.Errorf( "id=%d status=%d", id, rec.Code, ) } }(i) } wg.Wait() close(errCh) for err := range errCh { t.Error(err) } } // ---------- lifecycle ---------- func TestStopCancelsHandlers(t *testing.T) { srv := newTestRPCServer(t) done := make(chan struct{}) err := srv.RegisterMethods(MethodMap{ "block": func( ctx context.Context, req *Request, ) *Response { <-ctx.Done() close(done) return BuildResponse( req, "canceled", nil, ) }, }) require.NoError(t, err) go func() { postJSON( srv, `{"jsonrpc":"2.0","method":"block","id":1}`, ) }() time.Sleep(50 * time.Millisecond) srv.Stop() select { case <-done: case <-time.After(time.Second): t.Fatal("handler was not canceled after Stop()") } }