Browse Source

新增jsonrpc2.0基础服务端和客户端单元测试代码

niujiuru 1 day ago
parent
commit
a2379995b4
4 changed files with 602 additions and 0 deletions
  1. 4 0
      go.mod
  2. 1 0
      go.sum
  3. 221 0
      utils/jsonrpc2/client_test.go
  4. 376 0
      utils/jsonrpc2/server_test.go

+ 4 - 0
go.mod

@@ -12,6 +12,7 @@ require (
 	github.com/mattn/go-shellwords v1.0.12
 	github.com/peterh/liner v1.2.2
 	github.com/sirupsen/logrus v1.9.3
+	github.com/stretchr/testify v1.11.1
 	github.com/vishvananda/netlink v1.3.1
 	golang.org/x/sys v0.36.0
 	gopkg.in/ini.v1 v1.67.0
@@ -19,13 +20,16 @@ require (
 )
 
 require (
+	github.com/davecgh/go-spew v1.1.1 // indirect
 	github.com/gorilla/websocket v1.5.3 // indirect
 	github.com/hashicorp/errwrap v1.0.0 // indirect
 	github.com/hashicorp/go-multierror v1.1.1 // indirect
 	github.com/mattn/go-runewidth v0.0.3 // indirect
+	github.com/pmezard/go-difflib v1.0.0 // indirect
 	github.com/vishvananda/netns v0.0.5 // indirect
 	golang.org/x/net v0.44.0 // indirect
 	golang.org/x/sync v0.17.0 // indirect
+	gopkg.in/yaml.v3 v3.0.1 // indirect
 )
 
 replace gopkg.in/natefinch/lumberjack.v2 => ./vendor_lumberjack/lumberjack-2.2.1

+ 1 - 0
go.sum

@@ -49,6 +49,7 @@ golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
 golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
 golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
 gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
 gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
 gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=

+ 221 - 0
utils/jsonrpc2/client_test.go

@@ -0,0 +1,221 @@
+package jsonrpc2
+
+import (
+	"context"
+	"encoding/json"
+	"io"
+	"math"
+	"net/http"
+	"net/http/httptest"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// ---------- helpers ----------
+
+func newTestHTTPServer(t *testing.T, handler http.HandlerFunc) *httptest.Server {
+	t.Helper()
+	return httptest.NewServer(handler)
+}
+
+func echoHandler(t *testing.T) http.HandlerFunc {
+	return func(w http.ResponseWriter, r *http.Request) {
+		body, err := io.ReadAll(r.Body)
+		require.NoError(t, err)
+
+		var req Request
+		err = json.Unmarshal(body, &req)
+		require.NoError(t, err)
+
+		if req.IsNotification() {
+			w.WriteHeader(http.StatusOK)
+			return
+		}
+
+		resp := BuildResponse(&req, "ok", nil)
+		w.Header().Set("Content-Type", "application/json")
+		_ = json.NewEncoder(w).Encode(resp)
+	}
+}
+
+func errorHandler(code int) http.HandlerFunc {
+	return func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(code)
+	}
+}
+
+// ---------- tests ----------
+
+func TestNewRPCClient(t *testing.T) {
+	_, err := NewRPCClient("")
+	assert.Error(t, err)
+
+	_, err = NewRPCClient("http://[::1]:namedport")
+	assert.Error(t, err)
+
+	c, err := NewRPCClient("http://localhost")
+	require.NoError(t, err)
+	assert.NotNil(t, c)
+}
+
+func TestCallSuccess(t *testing.T) {
+	srv := newTestHTTPServer(t, echoHandler(t))
+	defer srv.Close()
+
+	c, err := NewRPCClient(srv.URL)
+	require.NoError(t, err)
+
+	resp, err := c.Call(context.Background(), "test", []int{1, 2})
+	require.NoError(t, err)
+
+	var result string
+	err = json.Unmarshal(resp.Result, &result)
+	require.NoError(t, err)
+
+	assert.Equal(t, "ok", result)
+}
+
+func TestCallTimeout(t *testing.T) {
+	slowHandler := func(w http.ResponseWriter, r *http.Request) {
+		time.Sleep(2 * time.Second)
+		w.WriteHeader(http.StatusOK)
+	}
+
+	srv := newTestHTTPServer(t, slowHandler)
+	defer srv.Close()
+
+	c, err := NewRPCClient(
+		srv.URL,
+		WithHTTPClient(&http.Client{
+			Timeout: 500 * time.Millisecond,
+		}),
+	)
+	require.NoError(t, err)
+
+	_, err = c.Call(context.Background(), "slow", nil)
+
+	assert.Error(t, err)
+	assert.ErrorIs(t, err, context.DeadlineExceeded)
+}
+
+func TestCallHTTPError(t *testing.T) {
+	c, err := NewRPCClient("http://127.0.0.1:0")
+	require.NoError(t, err)
+
+	_, err = c.Call(context.Background(), "test", nil)
+	assert.Error(t, err)
+}
+
+func TestCallNon200Status(t *testing.T) {
+	srv := newTestHTTPServer(t, errorHandler(500))
+	defer srv.Close()
+
+	c, err := NewRPCClient(srv.URL)
+	require.NoError(t, err)
+
+	_, err = c.Call(context.Background(), "test", nil)
+	assert.ErrorContains(t, err, "unexpected status code")
+}
+
+func TestNotify(t *testing.T) {
+	srv := newTestHTTPServer(t, echoHandler(t))
+	defer srv.Close()
+
+	c, err := NewRPCClient(srv.URL)
+	require.NoError(t, err)
+
+	err = c.Notify(context.Background(), "notify", nil)
+	assert.NoError(t, err)
+}
+
+func TestNotifyContextCancel(t *testing.T) {
+	srv := newTestHTTPServer(t, func(w http.ResponseWriter, r *http.Request) {
+		<-r.Context().Done()
+	})
+	defer srv.Close()
+
+	c, err := NewRPCClient(srv.URL)
+	require.NoError(t, err)
+
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel()
+
+	err = c.Notify(ctx, "notify", nil)
+	assert.Error(t, err)
+}
+
+func TestCallContextCancel(t *testing.T) {
+	srv := newTestHTTPServer(t, func(w http.ResponseWriter, r *http.Request) {
+		<-r.Context().Done()
+	})
+	defer srv.Close()
+
+	c, err := NewRPCClient(srv.URL)
+	require.NoError(t, err)
+
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel()
+
+	_, err = c.Call(ctx, "test", nil)
+	assert.Error(t, err)
+}
+
+func TestClientHeaders(t *testing.T) {
+	var gotHeader http.Header
+
+	srv := newTestHTTPServer(t, func(w http.ResponseWriter, r *http.Request) {
+		gotHeader = r.Header.Clone()
+		w.WriteHeader(http.StatusOK)
+	})
+	defer srv.Close()
+
+	c, err := NewRPCClient(
+		srv.URL,
+		SetHeader("Authorization", "Bearer token"),
+		SetHeader("X-Request-ID", "123"),
+	)
+	require.NoError(t, err)
+
+	_, _ = c.Call(context.Background(), "test", nil)
+
+	assert.Equal(t, "Bearer token", gotHeader.Get("Authorization"))
+	assert.Equal(t, "123", gotHeader.Get("X-Request-ID"))
+	assert.Equal(t, "application/json", gotHeader.Get("Content-Type"))
+}
+
+func TestNextIDWrapAround(t *testing.T) {
+	c := &RPCClient{
+		seq: atomic.Int64{},
+	}
+
+	c.seq.Store(math.MaxInt32)
+
+	id1 := c.nextID()
+	id2 := c.nextID()
+
+	assert.Equal(t, int64(1), id1)
+	assert.Equal(t, int64(2), id2)
+}
+
+func TestCustomHTTPClient(t *testing.T) {
+	srv := newTestHTTPServer(t, echoHandler(t))
+	defer srv.Close()
+
+	hc := &http.Client{
+		Timeout: 500 * time.Millisecond,
+	}
+
+	c, err := NewRPCClient(
+		srv.URL,
+		WithHTTPClient(hc),
+	)
+	require.NoError(t, err)
+
+	resp, err := c.Call(context.Background(), "test", nil)
+	require.NoError(t, err)
+	assert.NotNil(t, resp)
+}

+ 376 - 0
utils/jsonrpc2/server_test.go

@@ -0,0 +1,376 @@
+package jsonrpc2
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/sirupsen/logrus"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// ---------- helpers ----------
+
+func newTestRPCServer(t *testing.T) *RPCServer {
+	t.Helper()
+
+	logger := logrus.New()
+	logger.SetOutput(io.Discard) // silence logs during tests
+
+	srv, err := NewRPCServer("test-rpc", logger)
+	require.NoError(t, err)
+
+	srv.SetMaxBodySize(1024)
+
+	return srv
+}
+
+func postJSON(h http.Handler, body string) *httptest.ResponseRecorder {
+	req := httptest.NewRequest(
+		http.MethodPost,
+		"/rpc",
+		bytes.NewBufferString(body),
+	)
+
+	req.Header.Set("Content-Type", "application/json")
+
+	rec := httptest.NewRecorder()
+	h.ServeHTTP(rec, req)
+
+	return rec
+}
+
+func parseResponse(
+	t *testing.T,
+	rec *httptest.ResponseRecorder,
+) map[string]any {
+	t.Helper()
+
+	var resp map[string]any
+
+	err := json.Unmarshal(rec.Body.Bytes(), &resp)
+	require.NoError(t, err)
+
+	return resp
+}
+
+// ---------- basic ----------
+
+func TestPing(t *testing.T) {
+	srv := newTestRPCServer(t)
+
+	rec := postJSON(
+		srv,
+		`{"jsonrpc":"2.0","method":"ping","id":1}`,
+	)
+
+	assert.Equal(t, http.StatusOK, rec.Code)
+
+	resp := parseResponse(t, rec)
+
+	assert.Equal(t, "2.0", resp["jsonrpc"])
+	assert.Equal(t, "pong", resp["result"])
+	assert.Equal(t, float64(1), resp["id"])
+}
+
+func TestNotification(t *testing.T) {
+	srv := newTestRPCServer(t)
+
+	rec := postJSON(
+		srv,
+		`{"jsonrpc":"2.0","method":"ping"}`,
+	)
+
+	assert.Equal(t, http.StatusOK, rec.Code)
+	assert.Empty(t, rec.Body.String())
+}
+
+func TestOnlyPOST(t *testing.T) {
+	srv := newTestRPCServer(t)
+
+	req := httptest.NewRequest(http.MethodGet, "/rpc", nil)
+	rec := httptest.NewRecorder()
+
+	srv.ServeHTTP(rec, req)
+
+	assert.Equal(t, http.StatusMethodNotAllowed, rec.Code)
+}
+
+func TestMethodNotFound(t *testing.T) {
+	srv := newTestRPCServer(t)
+
+	rec := postJSON(
+		srv,
+		`{"jsonrpc":"2.0","method":"unknown","id":1}`,
+	)
+
+	assert.Equal(t, http.StatusOK, rec.Code)
+
+	resp := parseResponse(t, rec)
+
+	errObj, ok := resp["error"].(map[string]any)
+	require.True(t, ok)
+
+	assert.Equal(t, float64(-32601), errObj["code"])
+	assert.Contains(t, errObj["message"], "Method not found")
+}
+
+func TestInvalidJSON(t *testing.T) {
+	srv := newTestRPCServer(t)
+
+	rec := postJSON(srv, `{invalid json}`)
+
+	assert.Equal(t, http.StatusOK, rec.Code)
+
+	resp := parseResponse(t, rec)
+
+	errObj, ok := resp["error"].(map[string]any)
+	require.True(t, ok)
+
+	assert.Equal(t, float64(-32700), errObj["code"])
+}
+
+func TestBatchRejected(t *testing.T) {
+	srv := newTestRPCServer(t)
+
+	rec := postJSON(
+		srv,
+		`[{"jsonrpc":"2.0","method":"ping","id":1}]`,
+	)
+
+	assert.Equal(t, http.StatusOK, rec.Code)
+
+	resp := parseResponse(t, rec)
+
+	errObj, ok := resp["error"].(map[string]any)
+	require.True(t, ok)
+
+	assert.Equal(t, float64(-32600), errObj["code"])
+}
+
+func TestRequestBodyTooLarge(t *testing.T) {
+	srv := newTestRPCServer(t)
+
+	srv.SetMaxBodySize(10)
+
+	rec := postJSON(
+		srv,
+		`{"jsonrpc":"2.0","method":"ping","id":1}`,
+	)
+
+	assert.Equal(t, http.StatusBadRequest, rec.Code)
+}
+
+// ---------- handler ----------
+
+func TestHandlerPanic(t *testing.T) {
+	srv := newTestRPCServer(t)
+
+	err := srv.RegisterMethods(MethodMap{
+		"panic.method": func(
+			ctx context.Context,
+			req *Request,
+		) *Response {
+			panic("boom")
+		},
+	})
+
+	require.NoError(t, err)
+
+	rec := postJSON(
+		srv,
+		`{"jsonrpc":"2.0","method":"panic.method","id":1}`,
+	)
+
+	assert.Equal(t, http.StatusOK, rec.Code)
+
+	resp := parseResponse(t, rec)
+
+	errObj, ok := resp["error"].(map[string]any)
+	require.True(t, ok)
+
+	assert.Equal(t, float64(-32603), errObj["code"])
+}
+
+func TestNilHandlerResponse(t *testing.T) {
+	srv := newTestRPCServer(t)
+
+	err := srv.RegisterMethods(MethodMap{
+		"nil.response": func(
+			ctx context.Context,
+			req *Request,
+		) *Response {
+			return nil
+		},
+	})
+
+	require.NoError(t, err)
+
+	rec := postJSON(
+		srv,
+		`{"jsonrpc":"2.0","method":"nil.response","id":1}`,
+	)
+
+	assert.Equal(t, http.StatusOK, rec.Code)
+
+	resp := parseResponse(t, rec)
+
+	assert.Nil(t, resp["result"])
+}
+
+// ---------- register ----------
+
+func TestRegisterMethodsAtomic(t *testing.T) {
+	srv := newTestRPCServer(t)
+
+	err := srv.RegisterMethods(MethodMap{
+		"a": func(
+			ctx context.Context,
+			req *Request,
+		) *Response {
+			return BuildResponse(req, "ok", nil)
+		},
+	})
+
+	require.NoError(t, err)
+
+	err = srv.RegisterMethods(MethodMap{
+		"b": nil, // invalid
+	})
+
+	assert.Error(t, err)
+	assert.NotContains(t, srv.ListMethods(), "b")
+}
+
+func TestListMethodsSorted(t *testing.T) {
+	srv := newTestRPCServer(t)
+
+	err := srv.RegisterMethods(MethodMap{
+		"z": func(
+			ctx context.Context,
+			req *Request,
+		) *Response {
+			return nil
+		},
+		"a": func(
+			ctx context.Context,
+			req *Request,
+		) *Response {
+			return nil
+		},
+	})
+
+	require.NoError(t, err)
+
+	methods := srv.ListMethods()
+
+	assert.Equal(
+		t,
+		[]string{"a", "ping", "z"},
+		methods,
+	)
+}
+
+// ---------- concurrency ----------
+
+func TestConcurrentRequests(t *testing.T) {
+	srv := newTestRPCServer(t)
+
+	err := srv.RegisterMethods(MethodMap{
+		"echo": func(
+			ctx context.Context,
+			req *Request,
+		) *Response {
+			time.Sleep(10 * time.Millisecond)
+			return BuildResponse(req, "ok", nil)
+		},
+	})
+
+	require.NoError(t, err)
+
+	const workers = 100
+
+	var wg sync.WaitGroup
+	errCh := make(chan error, workers)
+
+	for i := 0; i < workers; i++ {
+		wg.Add(1)
+
+		go func(id int) {
+			defer wg.Done()
+
+			body := fmt.Sprintf(
+				`{"jsonrpc":"2.0","method":"echo","id":%d}`,
+				id,
+			)
+
+			rec := postJSON(srv, body)
+
+			if rec.Code != http.StatusOK {
+				errCh <- fmt.Errorf(
+					"id=%d status=%d",
+					id,
+					rec.Code,
+				)
+			}
+		}(i)
+	}
+
+	wg.Wait()
+	close(errCh)
+
+	for err := range errCh {
+		t.Error(err)
+	}
+}
+
+// ---------- lifecycle ----------
+
+func TestStopCancelsHandlers(t *testing.T) {
+	srv := newTestRPCServer(t)
+
+	done := make(chan struct{})
+
+	err := srv.RegisterMethods(MethodMap{
+		"block": func(
+			ctx context.Context,
+			req *Request,
+		) *Response {
+			<-ctx.Done()
+			close(done)
+
+			return BuildResponse(
+				req,
+				"canceled",
+				nil,
+			)
+		},
+	})
+
+	require.NoError(t, err)
+
+	go func() {
+		postJSON(
+			srv,
+			`{"jsonrpc":"2.0","method":"block","id":1}`,
+		)
+	}()
+
+	time.Sleep(50 * time.Millisecond)
+
+	srv.Stop()
+
+	select {
+	case <-done:
+	case <-time.After(time.Second):
+		t.Fatal("handler was not canceled after Stop()")
+	}
+}