瀏覽代碼

优化远程运维模块sshd,使其支持客户端并发连接

niujiuru 2 周之前
父節點
當前提交
31522d7b29
共有 4 個文件被更改,包括 149 次插入24 次删除
  1. 23 2
      sshd/protocol.go
  2. 4 8
      sshd/readme.txt
  3. 120 12
      sshd/sshd.go
  4. 2 2
      utils/shell/executor.go

+ 23 - 2
sshd/protocol.go

@@ -3,6 +3,7 @@ package sshd
 import (
 	"encoding/json"
 	"fmt"
+	"strings"
 
 	"hnyfkj.com.cn/rtu/linux/utils/jsonrpc2"
 	"hnyfkj.com.cn/rtu/linux/utils/shell"
@@ -10,7 +11,7 @@ import (
 
 func buildResp(req *jsonrpc2.Request, result any, err error) *jsonrpc2.Response {
 	if err != nil {
-		return jsonrpc2.BuildError(req, -32700, err.Error())
+		return jsonrpc2.BuildError(req, jsonrpc2.ErrParse, err.Error())
 	}
 
 	resp, err := jsonrpc2.BuildResult(req, result)
@@ -21,7 +22,7 @@ func buildResp(req *jsonrpc2.Request, result any, err error) *jsonrpc2.Response
 	return resp
 }
 
-func parseShellExecuteParams(params json.RawMessage) (shell.ExecuteParams, error) {
+func extractShellExecuteParams(params json.RawMessage) (shell.ExecuteParams, error) {
 	if len(params) == 0 {
 		return shell.ExecuteParams{}, fmt.Errorf("missing params")
 	}
@@ -33,3 +34,23 @@ func parseShellExecuteParams(params json.RawMessage) (shell.ExecuteParams, error
 
 	return p, nil
 }
+
+func extractClientID(params json.RawMessage) (string, error) {
+	if len(params) == 0 {
+		return "", fmt.Errorf("missing params")
+	}
+
+	var p struct {
+		ClientID string `json:"client_id"`
+	}
+	if err := json.Unmarshal(params, &p); err != nil {
+		return "", err
+	}
+
+	clientID := strings.TrimSpace(p.ClientID)
+	if clientID == "" {
+		return "", fmt.Errorf("clientID is required and cannot be blank")
+	}
+
+	return clientID, nil
+}

+ 4 - 8
sshd/readme.txt

@@ -49,9 +49,10 @@
 │        (Web / App / CLI 运维平台)        │
 └───────────────┬──────────────────────────┘
                 │ JSON-RPC 2.0
-                │  - shell.execute
-                │  - shell.interrupt
-                │  - ping
+                |  - executor.ping
+                │  - executor.exec
+                │  - executor.interrupt
+                │  - executor.close
                 v
 ┌──────────────────────────────────────────┐
 │               MQTT Broker                │
@@ -69,11 +70,6 @@
 │  - JSON-RPC 解析 / 校验                  │
 │  - 方法分发 (Method Dispatch)            │
 │                                          │
-│  RPC → 本地语义映射                       │
-│  --------------------------------------- │
-│  shell.execute   → Executor.Exec()       │
-│  shell.interrupt → Executor.Interrupt()  │
-│                                          │
 └───────────────┬──────────────────────────┘
                 │ 串行 / 单 Session
                 v

+ 120 - 12
sshd/sshd.go

@@ -4,6 +4,7 @@ import (
 	"context"
 	"errors"
 	"strings"
+	"sync"
 	"sync/atomic"
 	"time"
 
@@ -25,6 +26,9 @@ const (
 	MqttQos1     byte = 1               //// 消息至少送达一次
 	FastInterval      = 1 * time.Second //// 快速检测时间间隔
 	SlowInterval      = 5 * time.Second //// 慢速检测时间间隔
+
+	ExecutorCheckInterval = 2 * time.Second // 执行器回收检测
+	ExecutorTimeout       = 6 * time.Second // 执行器超时时间
 )
 
 var (
@@ -43,7 +47,9 @@ type MQTTCoupler struct {
 	ctx    context.Context
 	cancel context.CancelFunc
 
-	executor *shell.Executor // 本地执行器, 单实例-串行执行指令
+	///////// 本地执行器, 允许多客户端, 同一客户端串行的执行指令
+	executorMap   map[string]*clientExecutor
+	executorMapMu sync.Mutex
 
 	isConnected atomic.Bool /// 标记是否已连接MQTT的Broker服务
 
@@ -51,6 +57,21 @@ type MQTTCoupler struct {
 	registerRpcMeths *singletask.OnceTask // 注册方法, 单实例
 }
 
+type executorState int
+
+const (
+	execIdle    executorState = iota // 空闲状态时, 可安全回收
+	execRunning                      // 正在执行时, 不允许回收
+)
+
+type clientExecutor struct {
+	id       string
+	executor *shell.Executor
+	mu       sync.Mutex    ///////////////////// 同ID串行执行
+	lastPing time.Time     ///////////////////// 用于超时回收
+	state    executorState ///////////////////// 执行器的状态
+}
+
 func ModuleInit(mqttBroker, mqttUsername, mqttPassword string) bool {
 	if mqttBroker == "" {
 		baseapp.Logger.Errorf("[%s] 初始化远程运维模块失败: %v!!", MODULE_NAME, ErrBrokerAddressEmpty)
@@ -69,15 +90,16 @@ func ModuleInit(mqttBroker, mqttUsername, mqttPassword string) bool {
 		pubTopic:         "",
 		ctx:              ctx,
 		cancel:           cancel,
-		executor:         shell.NewExecutor(),
+		executorMap:      make(map[string]*clientExecutor),
 		isConnected:      atomic.Bool{},
 		registerRpcMeths: &singletask.OnceTask{},
 	}
 
-	if err := Coupler.init(); err != nil {
+	if err := Coupler.init2(); err != nil {
 		baseapp.Logger.Errorf("[%s] 初始化远程运维模块失败: %v!!", MODULE_NAME, err)
 		return false
 	}
+	go Coupler.startExecutorReaper(ExecutorCheckInterval, ExecutorTimeout)
 	go Coupler.keepOnline()
 
 	return true
@@ -89,7 +111,7 @@ func ModuleExit() {
 	}
 }
 
-func (c *MQTTCoupler) init() error {
+func (c *MQTTCoupler) init2() error {
 	c.imei = netmgrd.GetIMEI()
 	if c.imei == netmgrd.ErrUnknownModemTypeMsg || c.imei == "" {
 		return ErrIMEINotAvailable
@@ -197,7 +219,7 @@ func (c *MQTTCoupler) instRPCMethods() {
 }
 
 func (c *MQTTCoupler) handleRequests(client mqtt.Client, msg mqtt.Message) {
-	c.execOneCmd(msg)
+	go c.execOneCmd(msg)
 }
 
 func (c *MQTTCoupler) execOneCmd(msg mqtt.Message) {
@@ -205,6 +227,9 @@ func (c *MQTTCoupler) execOneCmd(msg mqtt.Message) {
 	baseapp.Logger.Infof("[%s] 收到一个RPC请求: %s", MODULE_NAME, str)
 
 	var resp *jsonrpc2.Response // 预定义一个空的应答
+	var clientID string         // 该客户端的唯一标识
+	var ce *clientExecutor      // 该客户端的指令执行器
+	var exists bool             // 判断执行器是否已存在
 
 	req, err := jsonrpc2.ParseRequest(str)
 	if err != nil || req.ID == nil /* 不接受通知类型的消息 */ {
@@ -212,28 +237,87 @@ func (c *MQTTCoupler) execOneCmd(msg mqtt.Message) {
 		goto retp
 	}
 
+	clientID, err = extractClientID(req.Params)
+	if err != nil {
+		resp = jsonrpc2.BuildError(req, jsonrpc2.ErrInvalidParams, err.Error())
+		goto retp
+	}
+
+	c.executorMapMu.Lock()
+	ce, exists = c.executorMap[clientID]
+	if !exists {
+		if len(c.executorMap) >= 3 {
+			c.executorMapMu.Unlock()
+			resp = jsonrpc2.BuildError(req, -32000, "connection refused: server has reached maximum client capacity (3/3)")
+			goto retp
+		}
+		ce = &clientExecutor{
+			id:       clientID,
+			executor: shell.NewExecutor(),
+			state:    execIdle,
+		}
+		c.executorMap[clientID] = ce
+	}
+	c.executorMapMu.Unlock()
+
+	ce.mu.Lock() // 确保同一客户端(ID一样)的指令串行执行
+	ce.lastPing = time.Now()
+
 	switch req.Method {
 	// Call-1: 心跳, 链路检测,"ping-pong"测试
-	case "ping":
+	case "executor.ping":
 		resp = buildResp(req, "pong", nil)
 	// Call-2:在本地shell中执行远程下发的指令
-	case "shell.execute":
-		params, err := parseShellExecuteParams(req.Params)
+	case "executor.exec":
+		params, err := extractShellExecuteParams(req.Params)
 		if err != nil {
-			resp = jsonrpc2.BuildError(req, -32700, err.Error())
+			ce.mu.Unlock()
+			resp = jsonrpc2.BuildError(req, jsonrpc2.ErrParse, err.Error())
 			goto retp
 		}
-		result, err := c.executor.Exec(params)
+
+		ce.state = execRunning
+		ce.mu.Unlock()
+
+		result, err := ce.executor.Exec(params)
+
+		ce.mu.Lock()
+		ce.state = execIdle
+		ce.lastPing = time.Now()
+		ce.mu.Unlock()
+
 		resp = buildResp(req, result, err)
+		goto retp
 	// Call-3:中断本地shell的执行,等价Ctrl+C
-	case "shell.interrupt":
-		err := c.executor.Interrupt()
+	case "executor.interrupt":
+		if ce.state != execRunning {
+			resp = jsonrpc2.BuildError(req, -32001, "no running command")
+			break
+		}
+		err := ce.executor.Interrupt()
 		resp = buildResp(req, "interrupted", err)
+	// Call-4:客户端安全退出, 释放本地的执行器
+	case "executor.close":
+		if ce.state == execRunning {
+			ce.mu.Unlock()
+			resp = jsonrpc2.BuildError(req, -32002, "executor busy, interrupt first")
+			goto retp
+		}
+		ce.mu.Unlock()
+
+		c.executorMapMu.Lock()
+		delete(c.executorMap, clientID)
+		c.executorMapMu.Unlock()
+
+		resp = buildResp(req, "bye", nil)
+		goto retp
 	// Call-?:无效, 远端调用了还不支持的-方法
 	default:
 		resp = jsonrpc2.BuildError(req, jsonrpc2.ErrMethodNotFound, "")
 	}
 
+	ce.mu.Unlock()
+
 retp:
 	text, err := resp.String()
 	if err != nil {
@@ -254,3 +338,27 @@ retp:
 
 	baseapp.Logger.Infof("[%s] 发送一个RPC应答, 报文内容: %s", MODULE_NAME, text)
 }
+
+func (c *MQTTCoupler) startExecutorReaper(interval, timeout time.Duration) {
+	ticker := time.NewTicker(interval)
+	defer ticker.Stop()
+	for {
+		select {
+		case <-c.ctx.Done():
+			return
+		case <-ticker.C:
+			c.executorMapMu.Lock()
+			for id, ce := range c.executorMap {
+				ce.mu.Lock()
+				expired := time.Since(ce.lastPing) > timeout
+				idle := (ce.state == execIdle)
+				ce.mu.Unlock()
+
+				if expired && idle {
+					delete(c.executorMap, id)
+				}
+			} // end for
+			c.executorMapMu.Unlock()
+		} // end select
+	} // end for
+}

+ 2 - 2
utils/shell/executor.go

@@ -7,8 +7,8 @@ import (
 )
 
 type Executor struct {
-	cwd string        // 当前所在目录
-	pg  *processGroup // 当前执行进程
+	cwd string        // 当前目录
+	pg  *processGroup // --进程组
 }
 
 func NewExecutor() *Executor {