execute.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. //go:build !windows
  2. // +build !windows
  3. // Author: NiuJiuRu
  4. // Email: niujiuru@qq.com
  5. package shell
  6. import (
  7. "bytes"
  8. "context"
  9. "fmt"
  10. "io"
  11. "os/exec"
  12. "syscall"
  13. "time"
  14. "github.com/mattn/go-shellwords"
  15. )
  16. const (
  17. gracePeriod = 2 * time.Second
  18. forceKillWait = 2 * time.Second
  19. exitTimeoutCode = 124
  20. maxOutputSize = 1 << 20 // 最大 1 MB, 限制输出大小
  21. checkProcessDelay = 50 * time.Millisecond
  22. )
  23. type limitedBuffer struct {
  24. buf *bytes.Buffer
  25. limit int
  26. }
  27. func (l *limitedBuffer) Write(p []byte) (int, error) {
  28. remain := l.limit - l.buf.Len()
  29. if remain <= 0 {
  30. return len(p), nil
  31. }
  32. if len(p) > remain {
  33. p = p[:remain]
  34. }
  35. return l.buf.Write(p)
  36. }
  37. type processGroup struct {
  38. cmd *exec.Cmd
  39. pgid int
  40. }
  41. // 进程组是否存在
  42. func (pg *processGroup) isProcessGroupAlive() bool {
  43. err := syscall.Kill(-pg.pgid, syscall.Signal(0))
  44. if err == nil {
  45. return true
  46. }
  47. if errno, ok := err.(syscall.Errno); ok && errno == syscall.ESRCH {
  48. return false
  49. }
  50. return true
  51. }
  52. // 等待进程组终止
  53. func (pg *processGroup) waitForTermination(timeout time.Duration) bool {
  54. deadline := time.Now().Add(timeout)
  55. for time.Now().Before(deadline) {
  56. if !pg.isProcessGroupAlive() {
  57. return true
  58. }
  59. time.Sleep(checkProcessDelay)
  60. }
  61. return !pg.isProcessGroupAlive()
  62. }
  63. // 终止整个进程组
  64. func (pg *processGroup) terminate() error {
  65. if pg.cmd.Process == nil || pg.pgid <= 0 || !pg.isProcessGroupAlive() {
  66. return nil
  67. }
  68. // 第一阶段: 尝试优雅终止, SIGTERM
  69. if err := syscall.Kill(-pg.pgid, syscall.SIGTERM); err != nil { // 如果发送信号失败,可能进程已经不存在
  70. if errno, ok := err.(syscall.Errno); ok && errno == syscall.ESRCH {
  71. return nil
  72. }
  73. }
  74. if pg.waitForTermination(gracePeriod) {
  75. return nil
  76. }
  77. // 第二阶段: 最后强制终止, SIGKILL
  78. if err := syscall.Kill(-pg.pgid, syscall.SIGKILL); err != nil { // 如果发送信号失败,可能进程已经不存在
  79. if errno, ok := err.(syscall.Errno); ok && errno == syscall.ESRCH {
  80. return nil
  81. }
  82. return fmt.Errorf("failed to send SIGKILL to process group: %w", err)
  83. }
  84. if pg.waitForTermination(forceKillWait) {
  85. return nil
  86. }
  87. return fmt.Errorf("%w: process group %d still alive after force kill",
  88. ErrExecutorLostControl, pg.pgid)
  89. }
  90. func executeInternal(p ExecuteParams, onStart func(pg *processGroup)) (*ExecuteResult, error) {
  91. swp := shellwords.NewParser()
  92. swp.ParseEnv = true // 展开 "环境变量"
  93. swp.ParseBacktick = true // 展开 `...`命令
  94. argv, err := swp.Parse(p.Cmd)
  95. if err != nil || len(argv) == 0 {
  96. return nil, ErrInvalidCommand
  97. }
  98. timeout := time.Duration(p.Timeout) * time.Second
  99. if timeout <= 0 {
  100. timeout = DefaultTimeout
  101. }
  102. ctx, cancel := context.WithTimeout(context.Background(), timeout)
  103. defer cancel()
  104. cmd := exec.Command(argv[0], argv[1:]...)
  105. if p.Dir != "" { // 设置工作目录
  106. cmd.Dir = p.Dir
  107. }
  108. cmd.SysProcAttr = &syscall.SysProcAttr{ // 新的进程组
  109. Setpgid: true,
  110. Pdeathsig: syscall.SIGKILL,
  111. }
  112. var stdout, stderr bytes.Buffer
  113. cmd.Stdout = io.Writer(&limitedBuffer{buf: &stdout, limit: maxOutputSize})
  114. cmd.Stderr = io.Writer(&limitedBuffer{buf: &stderr, limit: maxOutputSize})
  115. if err := cmd.Start(); err != nil {
  116. return nil, err
  117. }
  118. processInfo := &processGroup{
  119. cmd: cmd,
  120. pgid: cmd.Process.Pid, // 进程组ID就是主进程的PID
  121. }
  122. if onStart != nil {
  123. onStart(processInfo)
  124. }
  125. done := make(chan error, 1)
  126. go func() {
  127. done <- cmd.Wait()
  128. }()
  129. exitCode := 0
  130. var finalErr error
  131. select {
  132. case err := <-done: // 命令已结束, 输出结果
  133. if err != nil {
  134. if ee, ok := err.(*exec.ExitError); ok {
  135. exitCode = ee.ExitCode()
  136. } else {
  137. return nil, err
  138. }
  139. }
  140. case <-ctx.Done(): /// 超时, kill整个进程组
  141. exitCode = exitTimeoutCode
  142. if err := processInfo.terminate(); err != nil {
  143. finalErr = err
  144. break
  145. }
  146. timer := time.NewTimer(forceKillWait)
  147. defer timer.Stop()
  148. select {
  149. case <-done:
  150. case <-timer.C:
  151. finalErr = ErrExecutorLostControl
  152. }
  153. }
  154. return &ExecuteResult{
  155. Stdout: stdout.String(),
  156. Stderr: stderr.String(),
  157. ExitCode: exitCode,
  158. }, finalErr
  159. }
  160. func Execute(p ExecuteParams) (*ExecuteResult, error) {
  161. return executeInternal(p, nil)
  162. }