Parcourir la source

新增支持shell命令调用的模块

niujiuru il y a 1 semaine
Parent
commit
1570d1e948
4 fichiers modifiés avec 347 ajouts et 0 suppressions
  1. 1 0
      go.mod
  2. 2 0
      go.sum
  3. 198 0
      utils/shell/execute.go
  4. 146 0
      utils/shell/execute_test.go

+ 1 - 0
go.mod

@@ -6,6 +6,7 @@ require (
 	github.com/alexflint/go-filemutex v1.3.0
 	github.com/beevik/ntp v1.5.0
 	github.com/jlaffaye/ftp v0.2.0
+	github.com/mattn/go-shellwords v1.0.12
 	github.com/sirupsen/logrus v1.9.3
 	github.com/vishvananda/netlink v1.3.1
 	gopkg.in/ini.v1 v1.67.0

+ 2 - 0
go.sum

@@ -11,6 +11,8 @@ github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+l
 github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
 github.com/jlaffaye/ftp v0.2.0 h1:lXNvW7cBu7R/68bknOX3MrRIIqZ61zELs1P2RAiA3lg=
 github.com/jlaffaye/ftp v0.2.0/go.mod h1:is2Ds5qkhceAPy2xD6RLI6hmp/qysSoymZ+Z2uTnspI=
+github.com/mattn/go-shellwords v1.0.12 h1:M2zGm7EW6UQJvDeQxo4T51eKPurbeFbe8WtebGE2xrk=
+github.com/mattn/go-shellwords v1.0.12/go.mod h1:EZzvwXDESEeg03EKmM+RmDnNOPKG4lLtQsUlTZDWQ8Y=
 github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
 github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=

+ 198 - 0
utils/shell/execute.go

@@ -0,0 +1,198 @@
+package shell
+
+import (
+	"bytes"
+	"context"
+	"errors"
+	"fmt"
+	"io"
+	"os/exec"
+	"syscall"
+	"time"
+
+	"github.com/mattn/go-shellwords"
+)
+
+const (
+	defaultTimeout    = 5 * time.Second
+	gracePeriod       = 2 * time.Second
+	forceKillWait     = 2 * time.Second
+	exitTimeoutCode   = 124
+	maxOutputSize     = 1 << 20 // 最大 1 MB, 限制输出大小
+	checkProcessDelay = 50 * time.Millisecond
+)
+
+var (
+	ErrInvalidCommand      = errors.New("invalid command")
+	ErrExecutorLostControl = errors.New("executor lost control of process")
+)
+
+type ExecuteParams struct {
+	Cmd     string `json:"cmd"`               // 命令
+	Timeout int    `json:"timeout,omitempty"` // 超时(秒)
+}
+
+type ExecuteResult struct {
+	Stdout   string `json:"stdout"`    ///////// 标准输出
+	Stderr   string `json:"stderr"`    ///////// 错误输出
+	ExitCode int    `json:"exit_code"` ///////// 退出状态码: 0表示成功, 非0表示失败
+}
+
+type limitedBuffer struct {
+	buf   *bytes.Buffer
+	limit int
+}
+
+func (l *limitedBuffer) Write(p []byte) (int, error) {
+	remain := l.limit - l.buf.Len()
+	if remain <= 0 {
+		return len(p), nil
+	}
+	if len(p) > remain {
+		p = p[:remain]
+	}
+	return l.buf.Write(p)
+}
+
+type processGroup struct {
+	cmd  *exec.Cmd
+	pgid int
+}
+
+// 进程组是否存在
+func (pg *processGroup) isProcessGroupAlive() bool {
+	err := syscall.Kill(-pg.pgid, syscall.Signal(0))
+	if err == nil {
+		return true
+	}
+
+	if errno, ok := err.(syscall.Errno); ok && errno == syscall.ESRCH {
+		return false
+	}
+
+	return true
+}
+
+// 等待进程组终止
+func (pg *processGroup) waitForTermination(timeout time.Duration) bool {
+	deadline := time.Now().Add(timeout)
+
+	for time.Now().Before(deadline) {
+		if !pg.isProcessGroupAlive() {
+			return true
+		}
+		time.Sleep(checkProcessDelay)
+	}
+
+	return !pg.isProcessGroupAlive()
+}
+
+// 终止整个进程组
+func (pg *processGroup) terminate() error {
+	if pg.cmd.Process == nil || pg.pgid <= 0 || !pg.isProcessGroupAlive() {
+		return nil
+	}
+
+	// 第一阶段: 尝试优雅终止, SIGTERM
+	if err := syscall.Kill(-pg.pgid, syscall.SIGTERM); err != nil { // 如果发送信号失败,可能进程已经不存在
+		if errno, ok := err.(syscall.Errno); ok && errno == syscall.ESRCH {
+			return nil
+		}
+	}
+
+	if pg.waitForTermination(gracePeriod) {
+		return nil
+	}
+
+	// 第二阶段: 最后强制终止, SIGKILL
+	if err := syscall.Kill(-pg.pgid, syscall.SIGKILL); err != nil { // 如果发送信号失败,可能进程已经不存在
+		if errno, ok := err.(syscall.Errno); ok && errno == syscall.ESRCH {
+			return nil
+		}
+		return fmt.Errorf("failed to send SIGKILL to process group: %w", err)
+	}
+
+	if pg.waitForTermination(forceKillWait) {
+		return nil
+	}
+
+	return fmt.Errorf("%w: process group %d still alive after force kill",
+		ErrExecutorLostControl, pg.pgid)
+}
+
+func Execute(p ExecuteParams) (*ExecuteResult, error) {
+	swp := shellwords.NewParser()
+	swp.ParseEnv = true      // 展开 "环境变量"
+	swp.ParseBacktick = true // 展开 `...`命令
+
+	argv, err := swp.Parse(p.Cmd)
+	if err != nil || len(argv) == 0 {
+		return nil, ErrInvalidCommand
+	}
+
+	timeout := time.Duration(p.Timeout) * time.Second
+	if timeout <= 0 {
+		timeout = defaultTimeout
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), timeout)
+	defer cancel()
+
+	cmd := exec.Command(argv[0], argv[1:]...)
+
+	cmd.SysProcAttr = &syscall.SysProcAttr{ // 新的进程组
+		Setpgid:   true,
+		Pdeathsig: syscall.SIGKILL,
+	}
+
+	var stdout, stderr bytes.Buffer
+	cmd.Stdout = io.Writer(&limitedBuffer{buf: &stdout, limit: maxOutputSize})
+	cmd.Stderr = io.Writer(&limitedBuffer{buf: &stderr, limit: maxOutputSize})
+
+	if err := cmd.Start(); err != nil {
+		return nil, err
+	}
+
+	processInfo := &processGroup{
+		cmd:  cmd,
+		pgid: cmd.Process.Pid, // 进程组ID就是主进程的PID
+	}
+
+	done := make(chan error, 1)
+	go func() {
+		done <- cmd.Wait()
+	}()
+
+	exitCode := 0
+	var finalErr error
+
+	select {
+	case err := <-done: // 命令已结束, 输出结果
+		if err != nil {
+			if ee, ok := err.(*exec.ExitError); ok {
+				exitCode = ee.ExitCode()
+			} else {
+				return nil, err
+			}
+		}
+	case <-ctx.Done(): /// 超时, kill整个进程组
+		exitCode = exitTimeoutCode
+
+		if err := processInfo.terminate(); err != nil {
+			finalErr = err
+			break
+		}
+
+		select {
+		case <-done:
+		case <-time.After(forceKillWait):
+			finalErr = ErrExecutorLostControl
+		}
+	}
+
+	return &ExecuteResult{
+		Stdout:   stdout.String(),
+		Stderr:   stderr.String(),
+		ExitCode: exitCode,
+	}, finalErr
+}

+ 146 - 0
utils/shell/execute_test.go

@@ -0,0 +1,146 @@
+package shell
+
+import (
+	"os"
+	"os/exec"
+	"strings"
+	"testing"
+	"time"
+)
+
+// 执行正常的命令
+func TestExecute_Success(t *testing.T) {
+	res, err := Execute(ExecuteParams{
+		Cmd:     `echo hello`,
+		Timeout: 2,
+	})
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if res.ExitCode != 0 {
+		t.Fatalf("expected exit code 0, got %d", res.ExitCode)
+	}
+	if strings.TrimSpace(res.Stdout) != "hello" {
+		t.Fatalf("unexpected stdout: %q", res.Stdout)
+	}
+}
+
+// 执行命令不存在
+func TestExecute_CommandNotFound(t *testing.T) {
+	_, err := Execute(ExecuteParams{
+		Cmd:     `this_command_should_not_exist_123`,
+		Timeout: 1,
+	})
+	if err == nil {
+		t.Fatalf("expected error, got nil")
+	}
+}
+
+// 命令非零退出码
+func TestExecute_NonZeroExit(t *testing.T) {
+	res, err := Execute(ExecuteParams{
+		Cmd:     `sh -c "exit 42"`,
+		Timeout: 2,
+	})
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if res.ExitCode != 42 {
+		t.Fatalf("expected exit code 42, got %d", res.ExitCode)
+	}
+}
+
+// 命令超时退出时
+func TestExecute_Timeout_KillProcessGroup(t *testing.T) {
+	start := time.Now()
+
+	res, err := Execute(ExecuteParams{
+		Cmd:     `sh -c "sleep 100 & wait"`,
+		Timeout: 1,
+	})
+
+	elapsed := time.Since(start)
+	if elapsed > 5*time.Second {
+		t.Fatalf("Execute hung too long: %v", elapsed)
+	}
+
+	if res.ExitCode != 124 {
+		t.Fatalf("expected exit code 124, got %d", res.ExitCode)
+	}
+
+	if err != nil && err != ErrExecutorLostControl {
+		t.Fatalf("unexpected error: %v", err)
+	}
+}
+
+// 命令输出超限时
+func TestExecute_OutputLimit(t *testing.T) {
+	res, err := Execute(ExecuteParams{
+		Cmd:     `sh -c "yes | head -c 2097152"`, // 2MB
+		Timeout: 2,
+	})
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if len(res.Stdout) > 1<<20 {
+		t.Fatalf("stdout exceeded limit: %d", len(res.Stdout))
+	}
+}
+
+// 测试无孤儿进程
+func TestExecute_NoOrphanProcess(t *testing.T) {
+	_, _ = Execute(ExecuteParams{
+		Cmd:     `sh -c "sleep 100 & wait"`,
+		Timeout: 1,
+	})
+
+	time.Sleep(300 * time.Millisecond)
+
+	out, _ := exec.Command("pgrep", "-f", "sleep 100").Output()
+	if len(out) > 0 {
+		t.Fatalf("orphan sleep process detected: %s", out)
+	}
+}
+
+// 执行正常的命令
+func TestExecute_Ls(t *testing.T) {
+	res, err := Execute(ExecuteParams{
+		Cmd:     "ls -l",
+		Timeout: 2,
+	})
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	if res.ExitCode != 0 {
+		t.Fatalf("expected exit code 0, got %d", res.ExitCode)
+	}
+
+	if strings.TrimSpace(res.Stdout) == "" {
+		t.Fatalf("ls stdout should not be empty")
+	}
+}
+
+// 展开"环境变量"
+func TestExecute_WithEnvExpansion(t *testing.T) {
+	home := os.Getenv("HOME")
+	if home == "" {
+		t.Skip("HOME not set")
+	}
+
+	res, err := Execute(ExecuteParams{
+		Cmd:     `echo $HOME`,
+		Timeout: 2,
+	})
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	if res.ExitCode != 0 {
+		t.Fatalf("expected exit code 0, got %d", res.ExitCode)
+	}
+
+	if strings.TrimSpace(res.Stdout) != home {
+		t.Fatalf("unexpected stdout: %q", res.Stdout)
+	}
+}