server_test.go 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376
  1. package jsonrpc2
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "net/http"
  9. "net/http/httptest"
  10. "sync"
  11. "testing"
  12. "time"
  13. "github.com/sirupsen/logrus"
  14. "github.com/stretchr/testify/assert"
  15. "github.com/stretchr/testify/require"
  16. )
  17. // ---------- helpers ----------
  18. func newTestRPCServer(t *testing.T) *RPCServer {
  19. t.Helper()
  20. logger := logrus.New()
  21. logger.SetOutput(io.Discard) // silence logs during tests
  22. srv, err := NewRPCServer("test-rpc", logger)
  23. require.NoError(t, err)
  24. srv.SetMaxBodySize(1024)
  25. return srv
  26. }
  27. func postJSON(h http.Handler, body string) *httptest.ResponseRecorder {
  28. req := httptest.NewRequest(
  29. http.MethodPost,
  30. "/rpc",
  31. bytes.NewBufferString(body),
  32. )
  33. req.Header.Set("Content-Type", "application/json")
  34. rec := httptest.NewRecorder()
  35. h.ServeHTTP(rec, req)
  36. return rec
  37. }
  38. func parseResponse(
  39. t *testing.T,
  40. rec *httptest.ResponseRecorder,
  41. ) map[string]any {
  42. t.Helper()
  43. var resp map[string]any
  44. err := json.Unmarshal(rec.Body.Bytes(), &resp)
  45. require.NoError(t, err)
  46. return resp
  47. }
  48. // ---------- basic ----------
  49. func TestPing(t *testing.T) {
  50. srv := newTestRPCServer(t)
  51. rec := postJSON(
  52. srv,
  53. `{"jsonrpc":"2.0","method":"ping","id":1}`,
  54. )
  55. assert.Equal(t, http.StatusOK, rec.Code)
  56. resp := parseResponse(t, rec)
  57. assert.Equal(t, "2.0", resp["jsonrpc"])
  58. assert.Equal(t, "pong", resp["result"])
  59. assert.Equal(t, float64(1), resp["id"])
  60. }
  61. func TestNotification(t *testing.T) {
  62. srv := newTestRPCServer(t)
  63. rec := postJSON(
  64. srv,
  65. `{"jsonrpc":"2.0","method":"ping"}`,
  66. )
  67. assert.Equal(t, http.StatusOK, rec.Code)
  68. assert.Empty(t, rec.Body.String())
  69. }
  70. func TestOnlyPOST(t *testing.T) {
  71. srv := newTestRPCServer(t)
  72. req := httptest.NewRequest(http.MethodGet, "/rpc", nil)
  73. rec := httptest.NewRecorder()
  74. srv.ServeHTTP(rec, req)
  75. assert.Equal(t, http.StatusMethodNotAllowed, rec.Code)
  76. }
  77. func TestMethodNotFound(t *testing.T) {
  78. srv := newTestRPCServer(t)
  79. rec := postJSON(
  80. srv,
  81. `{"jsonrpc":"2.0","method":"unknown","id":1}`,
  82. )
  83. assert.Equal(t, http.StatusOK, rec.Code)
  84. resp := parseResponse(t, rec)
  85. errObj, ok := resp["error"].(map[string]any)
  86. require.True(t, ok)
  87. assert.Equal(t, float64(-32601), errObj["code"])
  88. assert.Contains(t, errObj["message"], "Method not found")
  89. }
  90. func TestInvalidJSON(t *testing.T) {
  91. srv := newTestRPCServer(t)
  92. rec := postJSON(srv, `{invalid json}`)
  93. assert.Equal(t, http.StatusOK, rec.Code)
  94. resp := parseResponse(t, rec)
  95. errObj, ok := resp["error"].(map[string]any)
  96. require.True(t, ok)
  97. assert.Equal(t, float64(-32700), errObj["code"])
  98. }
  99. func TestBatchRejected(t *testing.T) {
  100. srv := newTestRPCServer(t)
  101. rec := postJSON(
  102. srv,
  103. `[{"jsonrpc":"2.0","method":"ping","id":1}]`,
  104. )
  105. assert.Equal(t, http.StatusOK, rec.Code)
  106. resp := parseResponse(t, rec)
  107. errObj, ok := resp["error"].(map[string]any)
  108. require.True(t, ok)
  109. assert.Equal(t, float64(-32600), errObj["code"])
  110. }
  111. func TestRequestBodyTooLarge(t *testing.T) {
  112. srv := newTestRPCServer(t)
  113. srv.SetMaxBodySize(10)
  114. rec := postJSON(
  115. srv,
  116. `{"jsonrpc":"2.0","method":"ping","id":1}`,
  117. )
  118. assert.Equal(t, http.StatusBadRequest, rec.Code)
  119. }
  120. // ---------- handler ----------
  121. func TestHandlerPanic(t *testing.T) {
  122. srv := newTestRPCServer(t)
  123. err := srv.RegisterMethods(MethodMap{
  124. "panic.method": func(
  125. ctx context.Context,
  126. req *Request,
  127. ) *Response {
  128. panic("boom")
  129. },
  130. })
  131. require.NoError(t, err)
  132. rec := postJSON(
  133. srv,
  134. `{"jsonrpc":"2.0","method":"panic.method","id":1}`,
  135. )
  136. assert.Equal(t, http.StatusOK, rec.Code)
  137. resp := parseResponse(t, rec)
  138. errObj, ok := resp["error"].(map[string]any)
  139. require.True(t, ok)
  140. assert.Equal(t, float64(-32603), errObj["code"])
  141. }
  142. func TestNilHandlerResponse(t *testing.T) {
  143. srv := newTestRPCServer(t)
  144. err := srv.RegisterMethods(MethodMap{
  145. "nil.response": func(
  146. ctx context.Context,
  147. req *Request,
  148. ) *Response {
  149. return nil
  150. },
  151. })
  152. require.NoError(t, err)
  153. rec := postJSON(
  154. srv,
  155. `{"jsonrpc":"2.0","method":"nil.response","id":1}`,
  156. )
  157. assert.Equal(t, http.StatusOK, rec.Code)
  158. resp := parseResponse(t, rec)
  159. assert.Nil(t, resp["result"])
  160. }
  161. // ---------- register ----------
  162. func TestRegisterMethodsAtomic(t *testing.T) {
  163. srv := newTestRPCServer(t)
  164. err := srv.RegisterMethods(MethodMap{
  165. "a": func(
  166. ctx context.Context,
  167. req *Request,
  168. ) *Response {
  169. return BuildResponse(req, "ok", nil)
  170. },
  171. })
  172. require.NoError(t, err)
  173. err = srv.RegisterMethods(MethodMap{
  174. "b": nil, // invalid
  175. })
  176. assert.Error(t, err)
  177. assert.NotContains(t, srv.ListMethods(), "b")
  178. }
  179. func TestListMethodsSorted(t *testing.T) {
  180. srv := newTestRPCServer(t)
  181. err := srv.RegisterMethods(MethodMap{
  182. "z": func(
  183. ctx context.Context,
  184. req *Request,
  185. ) *Response {
  186. return nil
  187. },
  188. "a": func(
  189. ctx context.Context,
  190. req *Request,
  191. ) *Response {
  192. return nil
  193. },
  194. })
  195. require.NoError(t, err)
  196. methods := srv.ListMethods()
  197. assert.Equal(
  198. t,
  199. []string{"a", "ping", "z"},
  200. methods,
  201. )
  202. }
  203. // ---------- concurrency ----------
  204. func TestConcurrentRequests(t *testing.T) {
  205. srv := newTestRPCServer(t)
  206. err := srv.RegisterMethods(MethodMap{
  207. "echo": func(
  208. ctx context.Context,
  209. req *Request,
  210. ) *Response {
  211. time.Sleep(10 * time.Millisecond)
  212. return BuildResponse(req, "ok", nil)
  213. },
  214. })
  215. require.NoError(t, err)
  216. const workers = 100
  217. var wg sync.WaitGroup
  218. errCh := make(chan error, workers)
  219. for i := range workers {
  220. wg.Add(1)
  221. go func(id int) {
  222. defer wg.Done()
  223. body := fmt.Sprintf(
  224. `{"jsonrpc":"2.0","method":"echo","id":%d}`,
  225. id,
  226. )
  227. rec := postJSON(srv, body)
  228. if rec.Code != http.StatusOK {
  229. errCh <- fmt.Errorf(
  230. "id=%d status=%d",
  231. id,
  232. rec.Code,
  233. )
  234. }
  235. }(i)
  236. }
  237. wg.Wait()
  238. close(errCh)
  239. for err := range errCh {
  240. t.Error(err)
  241. }
  242. }
  243. // ---------- lifecycle ----------
  244. func TestStopCancelsHandlers(t *testing.T) {
  245. srv := newTestRPCServer(t)
  246. done := make(chan struct{})
  247. err := srv.RegisterMethods(MethodMap{
  248. "block": func(
  249. ctx context.Context,
  250. req *Request,
  251. ) *Response {
  252. <-ctx.Done()
  253. close(done)
  254. return BuildResponse(
  255. req,
  256. "canceled",
  257. nil,
  258. )
  259. },
  260. })
  261. require.NoError(t, err)
  262. go func() {
  263. postJSON(
  264. srv,
  265. `{"jsonrpc":"2.0","method":"block","id":1}`,
  266. )
  267. }()
  268. time.Sleep(50 * time.Millisecond)
  269. srv.Stop()
  270. select {
  271. case <-done:
  272. case <-time.After(time.Second):
  273. t.Fatal("handler was not canceled after Stop()")
  274. }
  275. }