| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221 |
- package jsonrpc2
- import (
- "context"
- "encoding/json"
- "io"
- "math"
- "net/http"
- "net/http/httptest"
- "sync/atomic"
- "testing"
- "time"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
- )
- // ---------- helpers ----------
// newTestHTTPServer starts an httptest.Server that serves handler and
// registers srv.Close with t.Cleanup. The server is therefore shut down when
// the test (or subtest) finishes, even if the caller never closes it.
// Closing it explicitly as well is safe: httptest.Server.Close only tears the
// listener down once.
func newTestHTTPServer(t *testing.T, handler http.HandlerFunc) *httptest.Server {
	t.Helper()
	srv := httptest.NewServer(handler)
	t.Cleanup(srv.Close)
	return srv
}
- func echoHandler(t *testing.T) http.HandlerFunc {
- return func(w http.ResponseWriter, r *http.Request) {
- body, err := io.ReadAll(r.Body)
- require.NoError(t, err)
- var req Request
- err = json.Unmarshal(body, &req)
- require.NoError(t, err)
- if req.IsNotification() {
- w.WriteHeader(http.StatusOK)
- return
- }
- resp := BuildResponse(&req, "ok", nil)
- w.Header().Set("Content-Type", "application/json")
- _ = json.NewEncoder(w).Encode(resp)
- }
- }
// errorHandler returns a handler that ignores the incoming request and
// replies with the given HTTP status code and an empty body.
func errorHandler(code int) http.HandlerFunc {
	return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(code)
	})
}
- // ---------- tests ----------
- func TestNewRPCClient(t *testing.T) {
- _, err := NewRPCClient("")
- assert.Error(t, err)
- _, err = NewRPCClient("http://[::1]:namedport")
- assert.Error(t, err)
- c, err := NewRPCClient("http://localhost")
- require.NoError(t, err)
- assert.NotNil(t, c)
- }
- func TestCallSuccess(t *testing.T) {
- srv := newTestHTTPServer(t, echoHandler(t))
- defer srv.Close()
- c, err := NewRPCClient(srv.URL)
- require.NoError(t, err)
- resp, err := c.Call(context.Background(), "test", []int{1, 2})
- require.NoError(t, err)
- var result string
- err = json.Unmarshal(resp.Result, &result)
- require.NoError(t, err)
- assert.Equal(t, "ok", result)
- }
- func TestCallTimeout(t *testing.T) {
- slowHandler := func(w http.ResponseWriter, r *http.Request) {
- time.Sleep(2 * time.Second)
- w.WriteHeader(http.StatusOK)
- }
- srv := newTestHTTPServer(t, slowHandler)
- defer srv.Close()
- c, err := NewRPCClient(
- srv.URL,
- WithHTTPClient(&http.Client{
- Timeout: 500 * time.Millisecond,
- }),
- )
- require.NoError(t, err)
- _, err = c.Call(context.Background(), "slow", nil)
- assert.Error(t, err)
- assert.ErrorIs(t, err, context.DeadlineExceeded)
- }
- func TestCallHTTPError(t *testing.T) {
- c, err := NewRPCClient("http://127.0.0.1:0")
- require.NoError(t, err)
- _, err = c.Call(context.Background(), "test", nil)
- assert.Error(t, err)
- }
- func TestCallNon200Status(t *testing.T) {
- srv := newTestHTTPServer(t, errorHandler(500))
- defer srv.Close()
- c, err := NewRPCClient(srv.URL)
- require.NoError(t, err)
- _, err = c.Call(context.Background(), "test", nil)
- assert.ErrorContains(t, err, "unexpected status code")
- }
- func TestNotify(t *testing.T) {
- srv := newTestHTTPServer(t, echoHandler(t))
- defer srv.Close()
- c, err := NewRPCClient(srv.URL)
- require.NoError(t, err)
- err = c.Notify(context.Background(), "notify", nil)
- assert.NoError(t, err)
- }
- func TestNotifyContextCancel(t *testing.T) {
- srv := newTestHTTPServer(t, func(w http.ResponseWriter, r *http.Request) {
- <-r.Context().Done()
- })
- defer srv.Close()
- c, err := NewRPCClient(srv.URL)
- require.NoError(t, err)
- ctx, cancel := context.WithCancel(context.Background())
- cancel()
- err = c.Notify(ctx, "notify", nil)
- assert.Error(t, err)
- }
- func TestCallContextCancel(t *testing.T) {
- srv := newTestHTTPServer(t, func(w http.ResponseWriter, r *http.Request) {
- <-r.Context().Done()
- })
- defer srv.Close()
- c, err := NewRPCClient(srv.URL)
- require.NoError(t, err)
- ctx, cancel := context.WithCancel(context.Background())
- cancel()
- _, err = c.Call(ctx, "test", nil)
- assert.Error(t, err)
- }
- func TestClientHeaders(t *testing.T) {
- var gotHeader http.Header
- srv := newTestHTTPServer(t, func(w http.ResponseWriter, r *http.Request) {
- gotHeader = r.Header.Clone()
- w.WriteHeader(http.StatusOK)
- })
- defer srv.Close()
- c, err := NewRPCClient(
- srv.URL,
- SetHeader("Authorization", "Bearer token"),
- SetHeader("X-Request-ID", "123"),
- )
- require.NoError(t, err)
- _, _ = c.Call(context.Background(), "test", nil)
- assert.Equal(t, "Bearer token", gotHeader.Get("Authorization"))
- assert.Equal(t, "123", gotHeader.Get("X-Request-ID"))
- assert.Equal(t, "application/json", gotHeader.Get("Content-Type"))
- }
- func TestNextIDWrapAround(t *testing.T) {
- c := &RPCClient{
- seq: atomic.Int64{},
- }
- c.seq.Store(math.MaxInt32)
- id1 := c.nextID()
- id2 := c.nextID()
- assert.Equal(t, int64(1), id1)
- assert.Equal(t, int64(2), id2)
- }
- func TestCustomHTTPClient(t *testing.T) {
- srv := newTestHTTPServer(t, echoHandler(t))
- defer srv.Close()
- hc := &http.Client{
- Timeout: 500 * time.Millisecond,
- }
- c, err := NewRPCClient(
- srv.URL,
- WithHTTPClient(hc),
- )
- require.NoError(t, err)
- resp, err := c.Call(context.Background(), "test", nil)
- require.NoError(t, err)
- assert.NotNil(t, resp)
- }
|