// execute.go
  1. // Author: NiuJiuRu
  2. // Email: niujiuru@qq.com
  3. package shell
  4. import (
  5. "bytes"
  6. "context"
  7. "errors"
  8. "fmt"
  9. "io"
  10. "os/exec"
  11. "syscall"
  12. "time"
  13. "github.com/mattn/go-shellwords"
  14. )
  15. const (
  16. DefaultTimeout = 5 * time.Second
  17. gracePeriod = 2 * time.Second
  18. forceKillWait = 2 * time.Second
  19. exitTimeoutCode = 124
  20. maxOutputSize = 1 << 20 // 最大 1 MB, 限制输出大小
  21. checkProcessDelay = 50 * time.Millisecond
  22. )
  23. var (
  24. ErrInvalidCommand = errors.New("invalid command")
  25. ErrExecutorLostControl = errors.New("executor lost control of process")
  26. )
  27. type ExecuteParams struct {
  28. Cmd string `json:"cmd"` // 命令
  29. Timeout int `json:"timeout,omitempty"` // 超时(秒)
  30. Dir string `json:"-"` // 工作目录
  31. }
  32. type ExecuteResult struct {
  33. Stdout string `json:"stdout"` ///////// 标准输出
  34. Stderr string `json:"stderr"` ///////// 错误输出
  35. ExitCode int `json:"exit_code"` ///////// 退出状态码: 0表示成功, 非0表示失败
  36. Cwd string `json:"cwd"` ///////// 当前目录
  37. }
  38. type limitedBuffer struct {
  39. buf *bytes.Buffer
  40. limit int
  41. }
  42. func (l *limitedBuffer) Write(p []byte) (int, error) {
  43. remain := l.limit - l.buf.Len()
  44. if remain <= 0 {
  45. return len(p), nil
  46. }
  47. if len(p) > remain {
  48. p = p[:remain]
  49. }
  50. return l.buf.Write(p)
  51. }
  52. type processGroup struct {
  53. cmd *exec.Cmd
  54. pgid int
  55. }
  56. // 进程组是否存在
  57. func (pg *processGroup) isProcessGroupAlive() bool {
  58. err := syscall.Kill(-pg.pgid, syscall.Signal(0))
  59. if err == nil {
  60. return true
  61. }
  62. if errno, ok := err.(syscall.Errno); ok && errno == syscall.ESRCH {
  63. return false
  64. }
  65. return true
  66. }
  67. // 等待进程组终止
  68. func (pg *processGroup) waitForTermination(timeout time.Duration) bool {
  69. deadline := time.Now().Add(timeout)
  70. for time.Now().Before(deadline) {
  71. if !pg.isProcessGroupAlive() {
  72. return true
  73. }
  74. time.Sleep(checkProcessDelay)
  75. }
  76. return !pg.isProcessGroupAlive()
  77. }
  78. // 终止整个进程组
  79. func (pg *processGroup) terminate() error {
  80. if pg.cmd.Process == nil || pg.pgid <= 0 || !pg.isProcessGroupAlive() {
  81. return nil
  82. }
  83. // 第一阶段: 尝试优雅终止, SIGTERM
  84. if err := syscall.Kill(-pg.pgid, syscall.SIGTERM); err != nil { // 如果发送信号失败,可能进程已经不存在
  85. if errno, ok := err.(syscall.Errno); ok && errno == syscall.ESRCH {
  86. return nil
  87. }
  88. }
  89. if pg.waitForTermination(gracePeriod) {
  90. return nil
  91. }
  92. // 第二阶段: 最后强制终止, SIGKILL
  93. if err := syscall.Kill(-pg.pgid, syscall.SIGKILL); err != nil { // 如果发送信号失败,可能进程已经不存在
  94. if errno, ok := err.(syscall.Errno); ok && errno == syscall.ESRCH {
  95. return nil
  96. }
  97. return fmt.Errorf("failed to send SIGKILL to process group: %w", err)
  98. }
  99. if pg.waitForTermination(forceKillWait) {
  100. return nil
  101. }
  102. return fmt.Errorf("%w: process group %d still alive after force kill",
  103. ErrExecutorLostControl, pg.pgid)
  104. }
  105. func executeInternal(p ExecuteParams, onStart func(pg *processGroup)) (*ExecuteResult, error) {
  106. swp := shellwords.NewParser()
  107. swp.ParseEnv = true // 展开 "环境变量"
  108. swp.ParseBacktick = true // 展开 `...`命令
  109. argv, err := swp.Parse(p.Cmd)
  110. if err != nil || len(argv) == 0 {
  111. return nil, ErrInvalidCommand
  112. }
  113. timeout := time.Duration(p.Timeout) * time.Second
  114. if timeout <= 0 {
  115. timeout = DefaultTimeout
  116. }
  117. ctx, cancel := context.WithTimeout(context.Background(), timeout)
  118. defer cancel()
  119. cmd := exec.Command(argv[0], argv[1:]...)
  120. if p.Dir != "" { // 设置工作目录
  121. cmd.Dir = p.Dir
  122. }
  123. cmd.SysProcAttr = &syscall.SysProcAttr{ // 新的进程组
  124. Setpgid: true,
  125. Pdeathsig: syscall.SIGKILL,
  126. }
  127. var stdout, stderr bytes.Buffer
  128. cmd.Stdout = io.Writer(&limitedBuffer{buf: &stdout, limit: maxOutputSize})
  129. cmd.Stderr = io.Writer(&limitedBuffer{buf: &stderr, limit: maxOutputSize})
  130. if err := cmd.Start(); err != nil {
  131. return nil, err
  132. }
  133. processInfo := &processGroup{
  134. cmd: cmd,
  135. pgid: cmd.Process.Pid, // 进程组ID就是主进程的PID
  136. }
  137. if onStart != nil {
  138. onStart(processInfo)
  139. }
  140. done := make(chan error, 1)
  141. go func() {
  142. done <- cmd.Wait()
  143. }()
  144. exitCode := 0
  145. var finalErr error
  146. select {
  147. case err := <-done: // 命令已结束, 输出结果
  148. if err != nil {
  149. if ee, ok := err.(*exec.ExitError); ok {
  150. exitCode = ee.ExitCode()
  151. } else {
  152. return nil, err
  153. }
  154. }
  155. case <-ctx.Done(): /// 超时, kill整个进程组
  156. exitCode = exitTimeoutCode
  157. if err := processInfo.terminate(); err != nil {
  158. finalErr = err
  159. break
  160. }
  161. timer := time.NewTimer(forceKillWait)
  162. defer timer.Stop()
  163. select {
  164. case <-done:
  165. case <-timer.C:
  166. finalErr = ErrExecutorLostControl
  167. }
  168. }
  169. return &ExecuteResult{
  170. Stdout: stdout.String(),
  171. Stderr: stderr.String(),
  172. ExitCode: exitCode,
  173. }, finalErr
  174. }
  175. func Execute(p ExecuteParams) (*ExecuteResult, error) {
  176. return executeInternal(p, nil)
  177. }