// Author: NiuJiuRu
// Email: niujiuru@qq.com
- package shell
- import (
- "bytes"
- "context"
- "errors"
- "fmt"
- "io"
- "os/exec"
- "syscall"
- "time"
- "github.com/mattn/go-shellwords"
- )
// Timing parameters for command execution and process-group teardown.
const (
	// defaultTimeout is the execution time limit applied when the caller
	// does not supply a positive timeout.
	defaultTimeout = 5 * time.Second

	// gracePeriod is how long a process group is given to exit after
	// SIGTERM before SIGKILL is sent.
	gracePeriod = 2 * time.Second

	// forceKillWait is how long to wait for a process group to disappear
	// after SIGKILL.
	forceKillWait = 2 * time.Second

	// checkProcessDelay is the pause between two liveness probes while
	// waiting for a process group to terminate.
	checkProcessDelay = 50 * time.Millisecond
)

// Numeric limits and codes.
const (
	// exitTimeoutCode is the exit code reported when a command times out.
	exitTimeoutCode = 124

	// maxOutputSize caps each captured output stream at 1 MB.
	maxOutputSize = 1024 * 1024
)
// ErrInvalidCommand is returned when the command string cannot be parsed
// or parses to an empty argument list.
var ErrInvalidCommand = errors.New("invalid command")

// ErrExecutorLostControl is returned when a timed-out command could not be
// confirmed dead after its process group was force-killed.
var ErrExecutorLostControl = errors.New("executor lost control of process")
// ExecuteParams is the input to Execute.
type ExecuteParams struct {
	// Cmd is the command line to run.
	Cmd string `json:"cmd"`

	// Timeout is the execution time limit in seconds.
	Timeout int `json:"timeout,omitempty"`
}
// ExecuteResult is the captured outcome of an executed command.
type ExecuteResult struct {
	// Stdout is the captured standard output.
	Stdout string `json:"stdout"`

	// Stderr is the captured standard error.
	Stderr string `json:"stderr"`

	// ExitCode is the exit status: 0 means success, non-zero means failure.
	ExitCode int `json:"exit_code"`
}
// limitedBuffer is an io.Writer that keeps at most limit bytes in buf and
// silently discards everything written beyond that.
type limitedBuffer struct {
	buf   *bytes.Buffer // destination for the retained bytes; allocated on first Write if nil
	limit int           // maximum number of bytes kept in buf
}

// Write appends as much of p to the buffer as the limit still allows and
// drops the rest.
//
// It always reports len(p) bytes as written, even when part or all of p was
// discarded. The io.Writer contract requires a non-nil error whenever fewer
// than len(p) bytes are reported, and io.Copy stops with io.ErrShortWrite
// on a short count, so reporting the truncated count would turn "output too
// large" into a copy failure instead of a truncation.
func (l *limitedBuffer) Write(p []byte) (int, error) {
	total := len(p)
	if l.buf == nil {
		l.buf = new(bytes.Buffer)
	}

	remain := l.limit - l.buf.Len()
	if remain <= 0 {
		// Limit already reached: swallow the data.
		return total, nil
	}

	keep := p
	if len(keep) > remain {
		keep = keep[:remain]
	}
	if _, err := l.buf.Write(keep); err != nil {
		return 0, err
	}
	return total, nil
}
// processGroup pairs a command with the ID of the process group that is
// signalled on its behalf.
type processGroup struct {
	cmd  *exec.Cmd // the command the group belongs to
	pgid int       // process group ID; signals are addressed to -pgid
}

// isProcessGroupAlive reports whether the process group still exists.
//
// It probes the group with signal 0, which performs the existence and
// permission checks without delivering a signal. Only ESRCH ("no such
// process") counts as gone; success and every other error (for example a
// permission error) are treated as alive.
func (pg *processGroup) isProcessGroupAlive() bool {
	probeErr := syscall.Kill(-pg.pgid, syscall.Signal(0))
	return probeErr != syscall.ESRCH
}
- // 等待进程组终止
- func (pg *processGroup) waitForTermination(timeout time.Duration) bool {
- deadline := time.Now().Add(timeout)
- for time.Now().Before(deadline) {
- if !pg.isProcessGroupAlive() {
- return true
- }
- time.Sleep(checkProcessDelay)
- }
- return !pg.isProcessGroupAlive()
- }
- // 终止整个进程组
- func (pg *processGroup) terminate() error {
- if pg.cmd.Process == nil || pg.pgid <= 0 || !pg.isProcessGroupAlive() {
- return nil
- }
- // 第一阶段: 尝试优雅终止, SIGTERM
- if err := syscall.Kill(-pg.pgid, syscall.SIGTERM); err != nil { // 如果发送信号失败,可能进程已经不存在
- if errno, ok := err.(syscall.Errno); ok && errno == syscall.ESRCH {
- return nil
- }
- }
- if pg.waitForTermination(gracePeriod) {
- return nil
- }
- // 第二阶段: 最后强制终止, SIGKILL
- if err := syscall.Kill(-pg.pgid, syscall.SIGKILL); err != nil { // 如果发送信号失败,可能进程已经不存在
- if errno, ok := err.(syscall.Errno); ok && errno == syscall.ESRCH {
- return nil
- }
- return fmt.Errorf("failed to send SIGKILL to process group: %w", err)
- }
- if pg.waitForTermination(forceKillWait) {
- return nil
- }
- return fmt.Errorf("%w: process group %d still alive after force kill",
- ErrExecutorLostControl, pg.pgid)
- }
- func Execute(p ExecuteParams) (*ExecuteResult, error) {
- swp := shellwords.NewParser()
- swp.ParseEnv = true // 展开 "环境变量"
- swp.ParseBacktick = true // 展开 `...`命令
- argv, err := swp.Parse(p.Cmd)
- if err != nil || len(argv) == 0 {
- return nil, ErrInvalidCommand
- }
- timeout := time.Duration(p.Timeout) * time.Second
- if timeout <= 0 {
- timeout = defaultTimeout
- }
- ctx, cancel := context.WithTimeout(context.Background(), timeout)
- defer cancel()
- cmd := exec.Command(argv[0], argv[1:]...)
- cmd.SysProcAttr = &syscall.SysProcAttr{ // 新的进程组
- Setpgid: true,
- Pdeathsig: syscall.SIGKILL,
- }
- var stdout, stderr bytes.Buffer
- cmd.Stdout = io.Writer(&limitedBuffer{buf: &stdout, limit: maxOutputSize})
- cmd.Stderr = io.Writer(&limitedBuffer{buf: &stderr, limit: maxOutputSize})
- if err := cmd.Start(); err != nil {
- return nil, err
- }
- processInfo := &processGroup{
- cmd: cmd,
- pgid: cmd.Process.Pid, // 进程组ID就是主进程的PID
- }
- done := make(chan error, 1)
- go func() {
- done <- cmd.Wait()
- }()
- exitCode := 0
- var finalErr error
- select {
- case err := <-done: // 命令已结束, 输出结果
- if err != nil {
- if ee, ok := err.(*exec.ExitError); ok {
- exitCode = ee.ExitCode()
- } else {
- return nil, err
- }
- }
- case <-ctx.Done(): /// 超时, kill整个进程组
- exitCode = exitTimeoutCode
- if err := processInfo.terminate(); err != nil {
- finalErr = err
- break
- }
- select {
- case <-done:
- case <-time.After(forceKillWait):
- finalErr = ErrExecutorLostControl
- }
- }
- return &ExecuteResult{
- Stdout: stdout.String(),
- Stderr: stderr.String(),
- ExitCode: exitCode,
- }, finalErr
- }
|