package jsonrpc2 import ( "context" "encoding/json" "io" "math" "net/http" "net/http/httptest" "sync/atomic" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) // ---------- helpers ---------- func newTestHTTPServer(t *testing.T, handler http.HandlerFunc) *httptest.Server { t.Helper() return httptest.NewServer(handler) } func echoHandler(t *testing.T) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { body, err := io.ReadAll(r.Body) require.NoError(t, err) var req Request err = json.Unmarshal(body, &req) require.NoError(t, err) if req.IsNotification() { w.WriteHeader(http.StatusOK) return } resp := BuildResponse(&req, "ok", nil) w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(resp) } } func errorHandler(code int) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(code) } } // ---------- tests ---------- func TestNewRPCClient(t *testing.T) { _, err := NewRPCClient("") assert.Error(t, err) _, err = NewRPCClient("http://[::1]:namedport") assert.Error(t, err) c, err := NewRPCClient("http://localhost") require.NoError(t, err) assert.NotNil(t, c) } func TestCallSuccess(t *testing.T) { srv := newTestHTTPServer(t, echoHandler(t)) defer srv.Close() c, err := NewRPCClient(srv.URL) require.NoError(t, err) resp, err := c.Call(context.Background(), "test", []int{1, 2}) require.NoError(t, err) var result string err = json.Unmarshal(resp.Result, &result) require.NoError(t, err) assert.Equal(t, "ok", result) } func TestCallTimeout(t *testing.T) { slowHandler := func(w http.ResponseWriter, r *http.Request) { time.Sleep(2 * time.Second) w.WriteHeader(http.StatusOK) } srv := newTestHTTPServer(t, slowHandler) defer srv.Close() c, err := NewRPCClient( srv.URL, WithHTTPClient(&http.Client{ Timeout: 500 * time.Millisecond, }), ) require.NoError(t, err) _, err = c.Call(context.Background(), "slow", nil) assert.Error(t, err) assert.ErrorIs(t, err, context.DeadlineExceeded) } func TestCallHTTPError(t *testing.T) { c, err := NewRPCClient("http://127.0.0.1:0") require.NoError(t, err) _, err = c.Call(context.Background(), "test", nil) assert.Error(t, err) } func TestCallNon200Status(t *testing.T) { srv := newTestHTTPServer(t, errorHandler(500)) defer srv.Close() c, err := NewRPCClient(srv.URL) require.NoError(t, err) _, err = c.Call(context.Background(), "test", nil) assert.ErrorContains(t, err, "unexpected status code") } func TestNotify(t *testing.T) { srv := newTestHTTPServer(t, echoHandler(t)) defer srv.Close() c, err := NewRPCClient(srv.URL) require.NoError(t, err) err = c.Notify(context.Background(), "notify", nil) assert.NoError(t, err) } func TestNotifyContextCancel(t *testing.T) { srv := newTestHTTPServer(t, func(w http.ResponseWriter, r *http.Request) { <-r.Context().Done() }) defer srv.Close() c, err := NewRPCClient(srv.URL) require.NoError(t, err) ctx, cancel := context.WithCancel(context.Background()) cancel() err = c.Notify(ctx, "notify", nil) assert.Error(t, err) } func TestCallContextCancel(t *testing.T) { srv := newTestHTTPServer(t, func(w http.ResponseWriter, r *http.Request) { <-r.Context().Done() }) defer srv.Close() c, err := NewRPCClient(srv.URL) require.NoError(t, err) ctx, cancel := context.WithCancel(context.Background()) cancel() _, err = c.Call(ctx, "test", nil) assert.Error(t, err) } func TestClientHeaders(t *testing.T) { var gotHeader http.Header srv := newTestHTTPServer(t, func(w http.ResponseWriter, r *http.Request) { gotHeader = r.Header.Clone() w.WriteHeader(http.StatusOK) }) defer srv.Close() c, err := NewRPCClient( srv.URL, SetHeader("Authorization", "Bearer token"), SetHeader("X-Request-ID", "123"), ) require.NoError(t, err) _, _ = c.Call(context.Background(), "test", nil) assert.Equal(t, "Bearer token", gotHeader.Get("Authorization")) assert.Equal(t, "123", gotHeader.Get("X-Request-ID")) assert.Equal(t, "application/json", gotHeader.Get("Content-Type")) } func TestNextIDWrapAround(t *testing.T) { c := &RPCClient{ seq: atomic.Int64{}, } c.seq.Store(math.MaxInt32) id1 := c.nextID() id2 := c.nextID() assert.Equal(t, int64(1), id1) assert.Equal(t, int64(2), id2) } func TestCustomHTTPClient(t *testing.T) { srv := newTestHTTPServer(t, echoHandler(t)) defer srv.Close() hc := &http.Client{ Timeout: 500 * time.Millisecond, } c, err := NewRPCClient( srv.URL, WithHTTPClient(hc), ) require.NoError(t, err) resp, err := c.Call(context.Background(), "test", nil) require.NoError(t, err) assert.NotNil(t, resp) }