client_test.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. package jsonrpc2
  2. import (
  3. "context"
  4. "encoding/json"
  5. "io"
  6. "math"
  7. "net/http"
  8. "net/http/httptest"
  9. "sync/atomic"
  10. "testing"
  11. "time"
  12. "github.com/stretchr/testify/assert"
  13. "github.com/stretchr/testify/require"
  14. )
  15. // ---------- helpers ----------
  16. func newTestHTTPServer(t *testing.T, handler http.HandlerFunc) *httptest.Server {
  17. t.Helper()
  18. return httptest.NewServer(handler)
  19. }
  20. func echoHandler(t *testing.T) http.HandlerFunc {
  21. return func(w http.ResponseWriter, r *http.Request) {
  22. body, err := io.ReadAll(r.Body)
  23. require.NoError(t, err)
  24. var req Request
  25. err = json.Unmarshal(body, &req)
  26. require.NoError(t, err)
  27. if req.IsNotification() {
  28. w.WriteHeader(http.StatusOK)
  29. return
  30. }
  31. resp := BuildResponse(&req, "ok", nil)
  32. w.Header().Set("Content-Type", "application/json")
  33. _ = json.NewEncoder(w).Encode(resp)
  34. }
  35. }
  36. func errorHandler(code int) http.HandlerFunc {
  37. return func(w http.ResponseWriter, r *http.Request) {
  38. w.WriteHeader(code)
  39. }
  40. }
  41. // ---------- tests ----------
  42. func TestNewRPCClient(t *testing.T) {
  43. _, err := NewRPCClient("")
  44. assert.Error(t, err)
  45. _, err = NewRPCClient("http://[::1]:namedport")
  46. assert.Error(t, err)
  47. c, err := NewRPCClient("http://localhost")
  48. require.NoError(t, err)
  49. assert.NotNil(t, c)
  50. }
  51. func TestCallSuccess(t *testing.T) {
  52. srv := newTestHTTPServer(t, echoHandler(t))
  53. defer srv.Close()
  54. c, err := NewRPCClient(srv.URL)
  55. require.NoError(t, err)
  56. resp, err := c.Call(context.Background(), "test", []int{1, 2})
  57. require.NoError(t, err)
  58. var result string
  59. err = json.Unmarshal(resp.Result, &result)
  60. require.NoError(t, err)
  61. assert.Equal(t, "ok", result)
  62. }
  63. func TestCallTimeout(t *testing.T) {
  64. slowHandler := func(w http.ResponseWriter, r *http.Request) {
  65. time.Sleep(2 * time.Second)
  66. w.WriteHeader(http.StatusOK)
  67. }
  68. srv := newTestHTTPServer(t, slowHandler)
  69. defer srv.Close()
  70. c, err := NewRPCClient(
  71. srv.URL,
  72. WithHTTPClient(&http.Client{
  73. Timeout: 500 * time.Millisecond,
  74. }),
  75. )
  76. require.NoError(t, err)
  77. _, err = c.Call(context.Background(), "slow", nil)
  78. assert.Error(t, err)
  79. assert.ErrorIs(t, err, context.DeadlineExceeded)
  80. }
  81. func TestCallHTTPError(t *testing.T) {
  82. c, err := NewRPCClient("http://127.0.0.1:0")
  83. require.NoError(t, err)
  84. _, err = c.Call(context.Background(), "test", nil)
  85. assert.Error(t, err)
  86. }
  87. func TestCallNon200Status(t *testing.T) {
  88. srv := newTestHTTPServer(t, errorHandler(500))
  89. defer srv.Close()
  90. c, err := NewRPCClient(srv.URL)
  91. require.NoError(t, err)
  92. _, err = c.Call(context.Background(), "test", nil)
  93. assert.ErrorContains(t, err, "unexpected status code")
  94. }
  95. func TestNotify(t *testing.T) {
  96. srv := newTestHTTPServer(t, echoHandler(t))
  97. defer srv.Close()
  98. c, err := NewRPCClient(srv.URL)
  99. require.NoError(t, err)
  100. err = c.Notify(context.Background(), "notify", nil)
  101. assert.NoError(t, err)
  102. }
  103. func TestNotifyContextCancel(t *testing.T) {
  104. srv := newTestHTTPServer(t, func(w http.ResponseWriter, r *http.Request) {
  105. <-r.Context().Done()
  106. })
  107. defer srv.Close()
  108. c, err := NewRPCClient(srv.URL)
  109. require.NoError(t, err)
  110. ctx, cancel := context.WithCancel(context.Background())
  111. cancel()
  112. err = c.Notify(ctx, "notify", nil)
  113. assert.Error(t, err)
  114. }
  115. func TestCallContextCancel(t *testing.T) {
  116. srv := newTestHTTPServer(t, func(w http.ResponseWriter, r *http.Request) {
  117. <-r.Context().Done()
  118. })
  119. defer srv.Close()
  120. c, err := NewRPCClient(srv.URL)
  121. require.NoError(t, err)
  122. ctx, cancel := context.WithCancel(context.Background())
  123. cancel()
  124. _, err = c.Call(ctx, "test", nil)
  125. assert.Error(t, err)
  126. }
  127. func TestClientHeaders(t *testing.T) {
  128. var gotHeader http.Header
  129. srv := newTestHTTPServer(t, func(w http.ResponseWriter, r *http.Request) {
  130. gotHeader = r.Header.Clone()
  131. w.WriteHeader(http.StatusOK)
  132. })
  133. defer srv.Close()
  134. c, err := NewRPCClient(
  135. srv.URL,
  136. SetHeader("Authorization", "Bearer token"),
  137. SetHeader("X-Request-ID", "123"),
  138. )
  139. require.NoError(t, err)
  140. _, _ = c.Call(context.Background(), "test", nil)
  141. assert.Equal(t, "Bearer token", gotHeader.Get("Authorization"))
  142. assert.Equal(t, "123", gotHeader.Get("X-Request-ID"))
  143. assert.Equal(t, "application/json", gotHeader.Get("Content-Type"))
  144. }
  145. func TestNextIDWrapAround(t *testing.T) {
  146. c := &RPCClient{
  147. seq: atomic.Int64{},
  148. }
  149. c.seq.Store(math.MaxInt32)
  150. id1 := c.nextID()
  151. id2 := c.nextID()
  152. assert.Equal(t, int64(1), id1)
  153. assert.Equal(t, int64(2), id2)
  154. }
  155. func TestCustomHTTPClient(t *testing.T) {
  156. srv := newTestHTTPServer(t, echoHandler(t))
  157. defer srv.Close()
  158. hc := &http.Client{
  159. Timeout: 500 * time.Millisecond,
  160. }
  161. c, err := NewRPCClient(
  162. srv.URL,
  163. WithHTTPClient(hc),
  164. )
  165. require.NoError(t, err)
  166. resp, err := c.Call(context.Background(), "test", nil)
  167. require.NoError(t, err)
  168. assert.NotNil(t, resp)
  169. }