| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376 |
- package jsonrpc2
- import (
- "bytes"
- "context"
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "net/http/httptest"
- "sync"
- "testing"
- "time"
- "github.com/sirupsen/logrus"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
- )
- // ---------- helpers ----------
// newTestRPCServer builds a server for a single test via
// NewRPCServer("test-rpc", ...), using a logrus logger whose output is
// discarded so test runs stay quiet. The max body size is then set to 1024
// (presumably bytes). The test aborts immediately if construction fails.
func newTestRPCServer(t *testing.T) *RPCServer {
	t.Helper()

	quietLogger := logrus.New()
	quietLogger.SetOutput(io.Discard)

	server, err := NewRPCServer("test-rpc", quietLogger)
	require.NoError(t, err)

	server.SetMaxBodySize(1024)
	return server
}
// postJSON sends body to h as an in-memory POST /rpc request with a
// "Content-Type: application/json" header and returns the recorder that
// captured h's reply.
func postJSON(h http.Handler, body string) *httptest.ResponseRecorder {
	recorder := httptest.NewRecorder()

	request := httptest.NewRequest(http.MethodPost, "/rpc", bytes.NewBufferString(body))
	request.Header.Set("Content-Type", "application/json")

	h.ServeHTTP(recorder, request)
	return recorder
}
// parseResponse decodes the recorded response body into a generic
// map[string]any, so JSON numbers come back as float64 and JSON null/missing
// fields both read as nil. The test stops immediately if the body cannot be
// decoded (e.g. it is empty or not JSON). The raw body is included in the
// failure message so a malformed server reply is visible in the test output.
func parseResponse(
	t *testing.T,
	rec *httptest.ResponseRecorder,
) map[string]any {
	t.Helper()

	var resp map[string]any
	err := json.Unmarshal(rec.Body.Bytes(), &resp)
	require.NoError(t, err, "decoding response body %q", rec.Body.String())
	return resp
}
- // ---------- basic ----------
// TestPing checks that the server's default "ping" method answers a
// JSON-RPC 2.0 request with result "pong" and echoes the request id.
func TestPing(t *testing.T) {
	server := newTestRPCServer(t)

	recorder := postJSON(server, `{"jsonrpc":"2.0","method":"ping","id":1}`)
	assert.Equal(t, http.StatusOK, recorder.Code)

	body := parseResponse(t, recorder)
	assert.Equal(t, "2.0", body["jsonrpc"])
	assert.Equal(t, "pong", body["result"])
	// Decoding into map[string]any turns JSON numbers into float64.
	assert.Equal(t, float64(1), body["id"])
}
// TestNotification checks that a request without an "id" (a notification)
// is answered with HTTP 200 and an empty body.
func TestNotification(t *testing.T) {
	server := newTestRPCServer(t)
	recorder := postJSON(server, `{"jsonrpc":"2.0","method":"ping"}`)

	assert.Equal(t, http.StatusOK, recorder.Code)
	assert.Empty(t, recorder.Body.String())
}
// TestOnlyPOST checks that a GET to the RPC endpoint is rejected with
// 405 Method Not Allowed.
func TestOnlyPOST(t *testing.T) {
	server := newTestRPCServer(t)

	recorder := httptest.NewRecorder()
	server.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/rpc", nil))

	assert.Equal(t, http.StatusMethodNotAllowed, recorder.Code)
}
// TestMethodNotFound checks that calling an unregistered method yields HTTP
// 200 with a JSON-RPC error object carrying code -32601 and a message that
// mentions "Method not found".
func TestMethodNotFound(t *testing.T) {
	server := newTestRPCServer(t)
	recorder := postJSON(server, `{"jsonrpc":"2.0","method":"unknown","id":1}`)
	assert.Equal(t, http.StatusOK, recorder.Code)

	rpcErr, isObject := parseResponse(t, recorder)["error"].(map[string]any)
	require.True(t, isObject)
	assert.Equal(t, float64(-32601), rpcErr["code"])
	assert.Contains(t, rpcErr["message"], "Method not found")
}
// TestInvalidJSON checks that a syntactically broken body produces a
// JSON-RPC parse error (-32700) while the HTTP status stays 200.
func TestInvalidJSON(t *testing.T) {
	server := newTestRPCServer(t)

	recorder := postJSON(server, `{invalid json}`)
	assert.Equal(t, http.StatusOK, recorder.Code)

	rpcErr, isObject := parseResponse(t, recorder)["error"].(map[string]any)
	require.True(t, isObject)
	assert.Equal(t, float64(-32700), rpcErr["code"])
}
- func TestBatchRejected(t *testing.T) {
- srv := newTestRPCServer(t)
- rec := postJSON(
- srv,
- `[{"jsonrpc":"2.0","method":"ping","id":1}]`,
- )
- assert.Equal(t, http.StatusOK, rec.Code)
- resp := parseResponse(t, rec)
- errObj, ok := resp["error"].(map[string]any)
- require.True(t, ok)
- assert.Equal(t, float64(-32600), errObj["code"])
- }
- func TestRequestBodyTooLarge(t *testing.T) {
- srv := newTestRPCServer(t)
- srv.SetMaxBodySize(10)
- rec := postJSON(
- srv,
- `{"jsonrpc":"2.0","method":"ping","id":1}`,
- )
- assert.Equal(t, http.StatusBadRequest, rec.Code)
- }
- // ---------- handler ----------
// TestHandlerPanic registers a handler that panics and checks that the
// client receives HTTP 200 with a JSON-RPC internal error (-32603) instead
// of the panic escaping ServeHTTP.
func TestHandlerPanic(t *testing.T) {
	server := newTestRPCServer(t)
	methods := MethodMap{
		"panic.method": func(_ context.Context, _ *Request) *Response {
			panic("boom")
		},
	}
	require.NoError(t, server.RegisterMethods(methods))

	recorder := postJSON(server, `{"jsonrpc":"2.0","method":"panic.method","id":1}`)
	assert.Equal(t, http.StatusOK, recorder.Code)

	rpcErr, isObject := parseResponse(t, recorder)["error"].(map[string]any)
	require.True(t, isObject)
	assert.Equal(t, float64(-32603), rpcErr["code"])
}
// TestNilHandlerResponse registers a handler that returns a nil *Response and
// checks that the server still replies with HTTP 200 and a decodable body
// whose "result" is absent or null.
func TestNilHandlerResponse(t *testing.T) {
	server := newTestRPCServer(t)
	returnsNil := func(_ context.Context, _ *Request) *Response { return nil }
	require.NoError(t, server.RegisterMethods(MethodMap{"nil.response": returnsNil}))

	recorder := postJSON(server, `{"jsonrpc":"2.0","method":"nil.response","id":1}`)
	assert.Equal(t, http.StatusOK, recorder.Code)

	// The map lookup yields nil whether "result" is missing or explicitly null.
	assert.Nil(t, parseResponse(t, recorder)["result"])
}
- // ---------- register ----------
// TestRegisterMethodsAtomic checks that RegisterMethods is all-or-nothing.
// A batch containing an invalid (nil) handler must fail, and none of that
// batch's methods, including the valid ones, may be registered. Methods from
// an earlier successful call must be kept.
func TestRegisterMethodsAtomic(t *testing.T) {
	srv := newTestRPCServer(t)
	okHandler := func(_ context.Context, req *Request) *Response {
		return BuildResponse(req, "ok", nil)
	}

	require.NoError(t, srv.RegisterMethods(MethodMap{"a": okHandler}))

	// Mix a valid handler with an invalid one. Atomicity means "c" must not be
	// registered even though it is valid on its own. Map iteration order is
	// random, so a register-while-validating implementation fails whenever
	// "c" happens to be visited before "b".
	err := srv.RegisterMethods(MethodMap{
		"b": nil, // invalid
		"c": okHandler,
	})
	assert.Error(t, err)

	methods := srv.ListMethods()
	assert.NotContains(t, methods, "b")
	assert.NotContains(t, methods, "c")
	assert.Contains(t, methods, "a")
}
// TestListMethodsSorted registers "z" and "a" and checks that ListMethods
// returns them together with the default "ping" method in ascending order.
func TestListMethodsSorted(t *testing.T) {
	server := newTestRPCServer(t)
	noop := func(_ context.Context, _ *Request) *Response { return nil }

	require.NoError(t, server.RegisterMethods(MethodMap{
		"z": noop,
		"a": noop,
	}))

	assert.Equal(t, []string{"a", "ping", "z"}, server.ListMethods())
}
- // ---------- concurrency ----------
// TestConcurrentRequests fires 100 overlapping requests at a slow "echo"
// handler. For each one it verifies the HTTP status, that the body decodes,
// that the response id matches the request id (catching responses crossed
// between concurrent calls), and that the result is the handler's "ok".
func TestConcurrentRequests(t *testing.T) {
	srv := newTestRPCServer(t)
	err := srv.RegisterMethods(MethodMap{
		"echo": func(_ context.Context, req *Request) *Response {
			// Hold each handler briefly so requests genuinely overlap.
			time.Sleep(10 * time.Millisecond)
			return BuildResponse(req, "ok", nil)
		},
	})
	require.NoError(t, err)

	const workers = 100
	var wg sync.WaitGroup
	// Each goroutine sends at most one error, so a buffer of workers never blocks.
	errCh := make(chan error, workers)

	for id := range workers {
		wg.Add(1)
		go func() {
			defer wg.Done()
			// require must not be called off the test goroutine (it uses
			// t.FailNow), so failures are reported through errCh instead.
			body := fmt.Sprintf(`{"jsonrpc":"2.0","method":"echo","id":%d}`, id)
			rec := postJSON(srv, body)
			if rec.Code != http.StatusOK {
				errCh <- fmt.Errorf("id=%d status=%d", id, rec.Code)
				return
			}

			var resp map[string]any
			if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
				errCh <- fmt.Errorf("id=%d decoding body %q: %w", id, rec.Body.String(), err)
				return
			}
			if got := resp["id"]; got != float64(id) {
				errCh <- fmt.Errorf("id=%d got response id %v", id, got)
				return
			}
			if got := resp["result"]; got != "ok" {
				errCh <- fmt.Errorf("id=%d result=%v", id, got)
			}
		}()
	}

	wg.Wait()
	close(errCh)
	for err := range errCh {
		t.Error(err)
	}
}
- // ---------- lifecycle ----------
// TestStopCancelsHandlers starts a request whose handler blocks on its
// context. Once the handler is known to be running, the test calls Stop().
// It then checks that the handler's context is canceled and that the
// in-flight request goroutine finishes, so nothing leaks past the test.
func TestStopCancelsHandlers(t *testing.T) {
	srv := newTestRPCServer(t)
	started := make(chan struct{}) // closed when the handler begins running
	done := make(chan struct{})    // closed after the handler observes cancellation
	err := srv.RegisterMethods(MethodMap{
		"block": func(ctx context.Context, req *Request) *Response {
			close(started)
			<-ctx.Done()
			close(done)
			return BuildResponse(req, "canceled", nil)
		},
	})
	require.NoError(t, err)

	requestDone := make(chan struct{})
	go func() {
		defer close(requestDone)
		postJSON(srv, `{"jsonrpc":"2.0","method":"block","id":1}`)
	}()

	// Synchronize on the handler actually running instead of sleeping, so
	// Stop() is never called before the request is in flight.
	select {
	case <-started:
	case <-time.After(time.Second):
		t.Fatal("handler did not start")
	}

	srv.Stop()

	select {
	case <-done:
	case <-time.After(time.Second):
		t.Fatal("handler was not canceled after Stop()")
	}

	// Make sure the request goroutine has returned before the test ends.
	select {
	case <-requestDone:
	case <-time.After(time.Second):
		t.Fatal("request did not complete after handler returned")
	}
}
|