execute.go 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. package shell
import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"io"
	"os/exec"
	"sync"
	"syscall"
	"time"

	"github.com/mattn/go-shellwords"
)
// Tunables for command execution and process-group teardown.
const (
	defaultTimeout    = 5 * time.Second       // used by Execute when the requested timeout is <= 0
	gracePeriod       = 2 * time.Second       // how long terminate waits after SIGTERM before sending SIGKILL
	forceKillWait     = 2 * time.Second       // how long to wait after SIGKILL, and for cmd.Wait after a timeout kill
	exitTimeoutCode   = 124                   // exit code Execute reports when the command timed out
	maxOutputSize     = 1 << 20               // at most 1 MB of output is kept, applied to stdout and stderr separately
	checkProcessDelay = 50 * time.Millisecond // pause between process-group liveness probes
)
var (
	// ErrInvalidCommand is returned by Execute when the command string fails
	// to parse or parses to an empty argument list.
	ErrInvalidCommand = errors.New("invalid command")
	// ErrExecutorLostControl is reported when the process group is still
	// alive after SIGKILL, or when the command's Wait does not return within
	// forceKillWait after the group was terminated.
	ErrExecutorLostControl = errors.New("executor lost control of process")
)
// ExecuteParams is the input to Execute.
type ExecuteParams struct {
	Cmd     string `json:"cmd"`               // command line to run
	Timeout int    `json:"timeout,omitempty"` // timeout in seconds; values <= 0 select defaultTimeout
}
// ExecuteResult is the captured outcome of a command run by Execute.
type ExecuteResult struct {
	Stdout   string `json:"stdout"`    // captured standard output
	Stderr   string `json:"stderr"`    // captured standard error
	ExitCode int    `json:"exit_code"` // exit status: 0 means success, non-zero means failure (exitTimeoutCode on timeout)
}
  34. type limitedBuffer struct {
  35. buf *bytes.Buffer
  36. limit int
  37. }
  38. func (l *limitedBuffer) Write(p []byte) (int, error) {
  39. remain := l.limit - l.buf.Len()
  40. if remain <= 0 {
  41. return len(p), nil
  42. }
  43. if len(p) > remain {
  44. p = p[:remain]
  45. }
  46. return l.buf.Write(p)
  47. }
// processGroup pairs a started command with the ID of its process group so
// the whole group can be probed and signalled at once.
type processGroup struct {
	cmd  *exec.Cmd // the started command; terminate checks cmd.Process before signalling
	pgid int       // process group ID; signals are addressed to -pgid
}
  52. // 进程组是否存在
  53. func (pg *processGroup) isProcessGroupAlive() bool {
  54. err := syscall.Kill(-pg.pgid, syscall.Signal(0))
  55. if err == nil {
  56. return true
  57. }
  58. if errno, ok := err.(syscall.Errno); ok && errno == syscall.ESRCH {
  59. return false
  60. }
  61. return true
  62. }
  63. // 等待进程组终止
  64. func (pg *processGroup) waitForTermination(timeout time.Duration) bool {
  65. deadline := time.Now().Add(timeout)
  66. for time.Now().Before(deadline) {
  67. if !pg.isProcessGroupAlive() {
  68. return true
  69. }
  70. time.Sleep(checkProcessDelay)
  71. }
  72. return !pg.isProcessGroupAlive()
  73. }
  74. // 终止整个进程组
  75. func (pg *processGroup) terminate() error {
  76. if pg.cmd.Process == nil || pg.pgid <= 0 || !pg.isProcessGroupAlive() {
  77. return nil
  78. }
  79. // 第一阶段: 尝试优雅终止, SIGTERM
  80. if err := syscall.Kill(-pg.pgid, syscall.SIGTERM); err != nil { // 如果发送信号失败,可能进程已经不存在
  81. if errno, ok := err.(syscall.Errno); ok && errno == syscall.ESRCH {
  82. return nil
  83. }
  84. }
  85. if pg.waitForTermination(gracePeriod) {
  86. return nil
  87. }
  88. // 第二阶段: 最后强制终止, SIGKILL
  89. if err := syscall.Kill(-pg.pgid, syscall.SIGKILL); err != nil { // 如果发送信号失败,可能进程已经不存在
  90. if errno, ok := err.(syscall.Errno); ok && errno == syscall.ESRCH {
  91. return nil
  92. }
  93. return fmt.Errorf("failed to send SIGKILL to process group: %w", err)
  94. }
  95. if pg.waitForTermination(forceKillWait) {
  96. return nil
  97. }
  98. return fmt.Errorf("%w: process group %d still alive after force kill",
  99. ErrExecutorLostControl, pg.pgid)
  100. }
  101. func Execute(p ExecuteParams) (*ExecuteResult, error) {
  102. swp := shellwords.NewParser()
  103. swp.ParseEnv = true // 展开 "环境变量"
  104. swp.ParseBacktick = true // 展开 `...`命令
  105. argv, err := swp.Parse(p.Cmd)
  106. if err != nil || len(argv) == 0 {
  107. return nil, ErrInvalidCommand
  108. }
  109. timeout := time.Duration(p.Timeout) * time.Second
  110. if timeout <= 0 {
  111. timeout = defaultTimeout
  112. }
  113. ctx, cancel := context.WithTimeout(context.Background(), timeout)
  114. defer cancel()
  115. cmd := exec.Command(argv[0], argv[1:]...)
  116. cmd.SysProcAttr = &syscall.SysProcAttr{ // 新的进程组
  117. Setpgid: true,
  118. Pdeathsig: syscall.SIGKILL,
  119. }
  120. var stdout, stderr bytes.Buffer
  121. cmd.Stdout = io.Writer(&limitedBuffer{buf: &stdout, limit: maxOutputSize})
  122. cmd.Stderr = io.Writer(&limitedBuffer{buf: &stderr, limit: maxOutputSize})
  123. if err := cmd.Start(); err != nil {
  124. return nil, err
  125. }
  126. processInfo := &processGroup{
  127. cmd: cmd,
  128. pgid: cmd.Process.Pid, // 进程组ID就是主进程的PID
  129. }
  130. done := make(chan error, 1)
  131. go func() {
  132. done <- cmd.Wait()
  133. }()
  134. exitCode := 0
  135. var finalErr error
  136. select {
  137. case err := <-done: // 命令已结束, 输出结果
  138. if err != nil {
  139. if ee, ok := err.(*exec.ExitError); ok {
  140. exitCode = ee.ExitCode()
  141. } else {
  142. return nil, err
  143. }
  144. }
  145. case <-ctx.Done(): /// 超时, kill整个进程组
  146. exitCode = exitTimeoutCode
  147. if err := processInfo.terminate(); err != nil {
  148. finalErr = err
  149. break
  150. }
  151. select {
  152. case <-done:
  153. case <-time.After(forceKillWait):
  154. finalErr = ErrExecutorLostControl
  155. }
  156. }
  157. return &ExecuteResult{
  158. Stdout: stdout.String(),
  159. Stderr: stderr.String(),
  160. ExitCode: exitCode,
  161. }, finalErr
  162. }