Quellcode durchsuchen

继续编写sshd远程运维模块的client代码

niujiuru vor 2 Wochen
Ursprung
Commit
27be5526c0
5 geänderte Dateien mit 193 neuen und 16 gelöschten Zeilen
  1. 1 1
      sshd/client/client.go
  2. 1 1
      sshd/client/config.go
  3. 122 6
      sshd/client/coupler.go
  4. 64 5
      sshd/client/invoker.go
  5. 5 3
      utils/shell/execute.go

+ 1 - 1
sshd/client/client.go

@@ -13,7 +13,7 @@ const MODULE_NAME = "YFKJ_SSH_CLIENT"
 
 var (
 	coupler               *MQTTCoupler
-	Version               = "0.0.0.1"
+	Version               = "1.0.0.1"
 	ErrBrokerAddressEmpty = errors.New("mqtt server address is empty")
 	ErrIMEINotAvailable   = errors.New("device imei is not available")
 )

+ 1 - 1
sshd/client/config.go

@@ -63,7 +63,7 @@ func loadAppConfig() error {
 	return nil
 }
 
-func GetCmdTimeoutByPrefix(cmd string) int {
+func getCmdTimeoutByPrefix(cmd string) int {
 	if cmd == "" || len(CfgServers.Cmds) == 0 {
 		return -1
 	}

+ 122 - 6
sshd/client/coupler.go

@@ -2,6 +2,7 @@ package main
 
 import (
 	"context"
+	"encoding/json"
 	"fmt"
 	"strings"
 	"sync"
@@ -9,6 +10,8 @@ import (
 	"time"
 
 	mqtt "github.com/eclipse/paho.mqtt.golang"
+	"hnyfkj.com.cn/rtu/linux/utils/jsonrpc2"
+	"hnyfkj.com.cn/rtu/linux/utils/shell"
 )
 
 const (
@@ -27,11 +30,14 @@ type MQTTCoupler struct {
 	clientID    string
 	isConnected atomic.Bool /////// 标记是否已连接MQTT的Broker服务
 
-	imei     string     // 设备唯一标识
-	subTopic string     // 订阅应答主题:/yfkj/device/rpc/imei/ack
-	pubTopic string     // 发布指令主题:/yfkj/device/rpc/imei/cmd
-	cwd      string     // 当前工作目录
-	mu       sync.Mutex // 串行执行的锁
+	imei     string // 设备唯一标识
+	subTopic string // 订阅应答主题:/yfkj/device/rpc/imei/ack
+	pubTopic string // 发布指令主题:/yfkj/device/rpc/imei/cmd
+	cwd      string // 当前工作目录
+
+	cmdMu     sync.Mutex                       // 串行执行的锁
+	pending   map[int]chan shell.ExecuteResult // 等待命令结果
+	pendingMu sync.Mutex                       // 等待结果的锁
 }
 
 func (c *MQTTCoupler) init() error {
@@ -50,6 +56,17 @@ func (c *MQTTCoupler) init() error {
 		if !c.isConnected.Swap(true) {
 			fmt.Printf("[%s] MQTT Broker连接成功", MODULE_NAME)
 		}
+		go func() { // 订阅应答主题
+			token := c.client.Subscribe(c.subTopic, MqttQos1, c.onCmdAck)
+			select {
+			case <-c.ctx.Done():
+				return
+			case <-token.Done():
+			}
+			if token.Error() != nil {
+				return
+			}
+		}()
 	}
 
 	opts.OnConnectionLost = func(client mqtt.Client, err error) {
@@ -58,8 +75,10 @@ func (c *MQTTCoupler) init() error {
 		}
 	}
 
+	c.pending = make(map[int]chan shell.ExecuteResult)
+
 	c.client = mqtt.NewClient(opts)
-	go coupler.keepOnline()
+	go c.keepOnline()
 
 	return nil
 }
@@ -104,3 +123,100 @@ func (c *MQTTCoupler) connect() error {
 
 	return token.Error()
 }
+
+func (c *MQTTCoupler) doCmd(method string, params any, id ...int) (shell.ExecuteResult, error) {
+	if c.needSerialize(method) {
+		c.cmdMu.Lock()
+		defer c.cmdMu.Unlock()
+	}
+
+	req, err := jsonrpc2.BuildRequest(method, params, id...)
+	if err != nil {
+		return shell.ExecuteResult{}, err
+	}
+	reqID := *req.ID
+
+	b, err := json.Marshal(req)
+	if err != nil {
+		return shell.ExecuteResult{}, err
+	}
+
+	ch := make(chan shell.ExecuteResult, 1)
+
+	c.pendingMu.Lock()
+	c.pending[reqID] = ch
+	c.pendingMu.Unlock()
+	defer func() {
+		c.pendingMu.Lock()
+		delete(c.pending, reqID)
+		c.pendingMu.Unlock()
+	}()
+
+	token := c.client.Publish(c.pubTopic, MqttQos1, false, b)
+
+	select {
+	case <-c.ctx.Done():
+		return shell.ExecuteResult{}, c.ctx.Err()
+	case <-token.Done():
+	}
+
+	if token.Error() != nil {
+		return shell.ExecuteResult{}, token.Error()
+	}
+
+	var timer *time.Timer
+	var timeout <-chan time.Time
+	if c.needTimeoutEnd(method) {
+		timer = time.NewTimer(shell.DefaultTimeout)
+		timeout = timer.C
+		defer timer.Stop()
+	}
+
+	select {
+	case <-c.ctx.Done():
+		return shell.ExecuteResult{}, c.ctx.Err()
+	case res := <-ch:
+		return res, nil
+	case <-timeout:
+		return shell.ExecuteResult{}, fmt.Errorf("command timeout")
+	}
+}
+
+func (c *MQTTCoupler) onCmdAck(client mqtt.Client, msg mqtt.Message) {
+	p := msg.Payload()
+
+	var resp jsonrpc2.Response
+	if err := json.Unmarshal(p, &resp); err != nil {
+		return
+	}
+
+	if resp.ID == nil { // 通知类消息, 设计上不应该出现
+		return
+	}
+	respID := *resp.ID
+
+	c.pendingMu.Lock()
+	ch, ok := c.pending[respID]
+	c.pendingMu.Unlock()
+
+	if !ok { /////////////// 未找到对应的请求, 忽略不管
+		return
+	}
+
+	var execResult shell.ExecuteResult
+
+	if resp.Error != nil { ////////////////// 错误应答
+		execResult.ExitCode = int(resp.Error.Code)
+		execResult.Stderr = resp.Error.Message
+	} else if len(resp.Result) > 0 { //////// 正确应答
+		if err := json.Unmarshal(resp.Result, &execResult); err != nil {
+			execResult.ExitCode = 1
+			execResult.Stderr = err.Error()
+		}
+	}
+
+	select {
+	case ch <- execResult:
+	default:
+	}
+}

+ 64 - 5
sshd/client/invoker.go

@@ -1,24 +1,83 @@
 package main
 
-import "hnyfkj.com.cn/rtu/linux/utils/shell"
+import (
+	"time"
+
+	"hnyfkj.com.cn/rtu/linux/utils/shell"
+)
+
+var (
+	rpc_ping = "executor.ping"
+	rpc_exec = "executor.exec"
+	rpc_stop = "executor.interrupt"
+	rpc_quit = "executor.close"
+)
+
+func (c *MQTTCoupler) needSerialize(method string) bool {
+	if method == rpc_ping || method == rpc_stop {
+		return false
+	}
+	return true
+}
+
+func (c *MQTTCoupler) needTimeoutEnd(method string) bool {
+	if method == rpc_ping || method == rpc_stop || method == rpc_quit {
+		return true
+	}
+	return false
+}
 
 // 心跳检测
 func (c *MQTTCoupler) Ping() (shell.ExecuteResult, error) {
-	return shell.ExecuteResult{}, nil
+	params := struct {
+		ClientID string `json:"client_id"`
+	}{
+		ClientID: c.clientID,
+	}
+
+	return c.doCmd(rpc_ping, params)
 }
 
 // 执行命令
 func (c *MQTTCoupler) Exec(
 	cmd string) (shell.ExecuteResult, error) {
-	return shell.ExecuteResult{}, nil
+	params := struct {
+		ClientID string `json:"client_id"`
+		shell.ExecuteParams
+	}{
+		ClientID: c.clientID,
+		ExecuteParams: shell.ExecuteParams{
+			Cmd:     cmd,
+			Timeout: int(shell.DefaultTimeout / time.Second),
+		},
+	}
+
+	timeout := getCmdTimeoutByPrefix(cmd)
+	if timeout > 0 {
+		params.Timeout = timeout
+	}
+
+	return c.doCmd(rpc_exec, params)
 }
 
 // 中断执行
 func (c *MQTTCoupler) Stop() (shell.ExecuteResult, error) {
-	return shell.ExecuteResult{}, nil
+	params := struct {
+		ClientID string `json:"client_id"`
+	}{
+		ClientID: c.clientID,
+	}
+
+	return c.doCmd(rpc_stop, params)
 }
 
 // 关闭退出
 func (c *MQTTCoupler) Quit() (shell.ExecuteResult, error) {
-	return shell.ExecuteResult{}, nil
+	params := struct {
+		ClientID string `json:"client_id"`
+	}{
+		ClientID: c.clientID,
+	}
+
+	return c.doCmd(rpc_quit, params)
 }

+ 5 - 3
utils/shell/execute.go

@@ -17,7 +17,7 @@ import (
 )
 
 const (
-	defaultTimeout    = 3 * time.Second
+	DefaultTimeout    = 3 * time.Second
 	gracePeriod       = 2 * time.Second
 	forceKillWait     = 2 * time.Second
 	exitTimeoutCode   = 124
@@ -137,7 +137,7 @@ func executeInternal(p ExecuteParams, onStart func(pg *processGroup)) (*ExecuteR
 
 	timeout := time.Duration(p.Timeout) * time.Second
 	if timeout <= 0 {
-		timeout = defaultTimeout
+		timeout = DefaultTimeout
 	}
 
 	ctx, cancel := context.WithTimeout(context.Background(), timeout)
@@ -196,9 +196,11 @@ func executeInternal(p ExecuteParams, onStart func(pg *processGroup)) (*ExecuteR
 			break
 		}
 
+		timer := time.NewTimer(forceKillWait)
+		defer timer.Stop()
 		select {
 		case <-done:
-		case <-time.After(forceKillWait):
+		case <-timer.C:
 			finalErr = ErrExecutorLostControl
 		}
 	}