|
|
@@ -0,0 +1,198 @@
|
|
|
+package shell
|
|
|
+
|
|
|
+import (
|
|
|
+ "bytes"
|
|
|
+ "context"
|
|
|
+ "errors"
|
|
|
+ "fmt"
|
|
|
+ "io"
|
|
|
+ "os/exec"
|
|
|
+ "syscall"
|
|
|
+ "time"
|
|
|
+
|
|
|
+ "github.com/mattn/go-shellwords"
|
|
|
+)
|
|
|
+
|
|
|
+const (
|
|
|
+ defaultTimeout = 5 * time.Second
|
|
|
+ gracePeriod = 2 * time.Second
|
|
|
+ forceKillWait = 2 * time.Second
|
|
|
+ exitTimeoutCode = 124
|
|
|
+ maxOutputSize = 1 << 20 // 最大 1 MB, 限制输出大小
|
|
|
+ checkProcessDelay = 50 * time.Millisecond
|
|
|
+)
|
|
|
+
|
|
|
+var (
|
|
|
+ ErrInvalidCommand = errors.New("invalid command")
|
|
|
+ ErrExecutorLostControl = errors.New("executor lost control of process")
|
|
|
+)
|
|
|
+
|
|
|
+type ExecuteParams struct {
|
|
|
+ Cmd string `json:"cmd"` // 命令
|
|
|
+ Timeout int `json:"timeout,omitempty"` // 超时(秒)
|
|
|
+}
|
|
|
+
|
|
|
+type ExecuteResult struct {
|
|
|
+ Stdout string `json:"stdout"` ///////// 标准输出
|
|
|
+ Stderr string `json:"stderr"` ///////// 错误输出
|
|
|
+ ExitCode int `json:"exit_code"` ///////// 退出状态码: 0表示成功, 非0表示失败
|
|
|
+}
|
|
|
+
|
|
|
+type limitedBuffer struct {
|
|
|
+ buf *bytes.Buffer
|
|
|
+ limit int
|
|
|
+}
|
|
|
+
|
|
|
+func (l *limitedBuffer) Write(p []byte) (int, error) {
|
|
|
+ remain := l.limit - l.buf.Len()
|
|
|
+ if remain <= 0 {
|
|
|
+ return len(p), nil
|
|
|
+ }
|
|
|
+ if len(p) > remain {
|
|
|
+ p = p[:remain]
|
|
|
+ }
|
|
|
+ return l.buf.Write(p)
|
|
|
+}
|
|
|
+
|
|
|
+type processGroup struct {
|
|
|
+ cmd *exec.Cmd
|
|
|
+ pgid int
|
|
|
+}
|
|
|
+
|
|
|
+// 进程组是否存在
|
|
|
+func (pg *processGroup) isProcessGroupAlive() bool {
|
|
|
+ err := syscall.Kill(-pg.pgid, syscall.Signal(0))
|
|
|
+ if err == nil {
|
|
|
+ return true
|
|
|
+ }
|
|
|
+
|
|
|
+ if errno, ok := err.(syscall.Errno); ok && errno == syscall.ESRCH {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+
|
|
|
+ return true
|
|
|
+}
|
|
|
+
|
|
|
+// 等待进程组终止
|
|
|
+func (pg *processGroup) waitForTermination(timeout time.Duration) bool {
|
|
|
+ deadline := time.Now().Add(timeout)
|
|
|
+
|
|
|
+ for time.Now().Before(deadline) {
|
|
|
+ if !pg.isProcessGroupAlive() {
|
|
|
+ return true
|
|
|
+ }
|
|
|
+ time.Sleep(checkProcessDelay)
|
|
|
+ }
|
|
|
+
|
|
|
+ return !pg.isProcessGroupAlive()
|
|
|
+}
|
|
|
+
|
|
|
+// 终止整个进程组
|
|
|
+func (pg *processGroup) terminate() error {
|
|
|
+ if pg.cmd.Process == nil || pg.pgid <= 0 || !pg.isProcessGroupAlive() {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+
|
|
|
+ // 第一阶段: 尝试优雅终止, SIGTERM
|
|
|
+ if err := syscall.Kill(-pg.pgid, syscall.SIGTERM); err != nil { // 如果发送信号失败,可能进程已经不存在
|
|
|
+ if errno, ok := err.(syscall.Errno); ok && errno == syscall.ESRCH {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ if pg.waitForTermination(gracePeriod) {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+
|
|
|
+ // 第二阶段: 最后强制终止, SIGKILL
|
|
|
+ if err := syscall.Kill(-pg.pgid, syscall.SIGKILL); err != nil { // 如果发送信号失败,可能进程已经不存在
|
|
|
+ if errno, ok := err.(syscall.Errno); ok && errno == syscall.ESRCH {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ return fmt.Errorf("failed to send SIGKILL to process group: %w", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ if pg.waitForTermination(forceKillWait) {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+
|
|
|
+ return fmt.Errorf("%w: process group %d still alive after force kill",
|
|
|
+ ErrExecutorLostControl, pg.pgid)
|
|
|
+}
|
|
|
+
|
|
|
+func Execute(p ExecuteParams) (*ExecuteResult, error) {
|
|
|
+ swp := shellwords.NewParser()
|
|
|
+ swp.ParseEnv = true // 展开 "环境变量"
|
|
|
+ swp.ParseBacktick = true // 展开 `...`命令
|
|
|
+
|
|
|
+ argv, err := swp.Parse(p.Cmd)
|
|
|
+ if err != nil || len(argv) == 0 {
|
|
|
+ return nil, ErrInvalidCommand
|
|
|
+ }
|
|
|
+
|
|
|
+ timeout := time.Duration(p.Timeout) * time.Second
|
|
|
+ if timeout <= 0 {
|
|
|
+ timeout = defaultTimeout
|
|
|
+ }
|
|
|
+
|
|
|
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
|
|
+ defer cancel()
|
|
|
+
|
|
|
+ cmd := exec.Command(argv[0], argv[1:]...)
|
|
|
+
|
|
|
+ cmd.SysProcAttr = &syscall.SysProcAttr{ // 新的进程组
|
|
|
+ Setpgid: true,
|
|
|
+ Pdeathsig: syscall.SIGKILL,
|
|
|
+ }
|
|
|
+
|
|
|
+ var stdout, stderr bytes.Buffer
|
|
|
+ cmd.Stdout = io.Writer(&limitedBuffer{buf: &stdout, limit: maxOutputSize})
|
|
|
+ cmd.Stderr = io.Writer(&limitedBuffer{buf: &stderr, limit: maxOutputSize})
|
|
|
+
|
|
|
+ if err := cmd.Start(); err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ processInfo := &processGroup{
|
|
|
+ cmd: cmd,
|
|
|
+ pgid: cmd.Process.Pid, // 进程组ID就是主进程的PID
|
|
|
+ }
|
|
|
+
|
|
|
+ done := make(chan error, 1)
|
|
|
+ go func() {
|
|
|
+ done <- cmd.Wait()
|
|
|
+ }()
|
|
|
+
|
|
|
+ exitCode := 0
|
|
|
+ var finalErr error
|
|
|
+
|
|
|
+ select {
|
|
|
+ case err := <-done: // 命令已结束, 输出结果
|
|
|
+ if err != nil {
|
|
|
+ if ee, ok := err.(*exec.ExitError); ok {
|
|
|
+ exitCode = ee.ExitCode()
|
|
|
+ } else {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ }
|
|
|
+ case <-ctx.Done(): /// 超时, kill整个进程组
|
|
|
+ exitCode = exitTimeoutCode
|
|
|
+
|
|
|
+ if err := processInfo.terminate(); err != nil {
|
|
|
+ finalErr = err
|
|
|
+ break
|
|
|
+ }
|
|
|
+
|
|
|
+ select {
|
|
|
+ case <-done:
|
|
|
+ case <-time.After(forceKillWait):
|
|
|
+ finalErr = ErrExecutorLostControl
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return &ExecuteResult{
|
|
|
+ Stdout: stdout.String(),
|
|
|
+ Stderr: stderr.String(),
|
|
|
+ ExitCode: exitCode,
|
|
|
+ }, finalErr
|
|
|
+}
|