rule_linux.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378
  1. package netlink
  2. import (
  3. "bytes"
  4. "errors"
  5. "fmt"
  6. "net"
  7. "github.com/vishvananda/netlink/nl"
  8. "golang.org/x/sys/unix"
  9. )
  10. const FibRuleInvert = 0x2
  11. // RuleAdd adds a rule to the system.
  12. // Equivalent to: ip rule add
  13. func RuleAdd(rule *Rule) error {
  14. return pkgHandle.RuleAdd(rule)
  15. }
  16. // RuleAdd adds a rule to the system.
  17. // Equivalent to: ip rule add
  18. func (h *Handle) RuleAdd(rule *Rule) error {
  19. req := h.newNetlinkRequest(unix.RTM_NEWRULE, unix.NLM_F_CREATE|unix.NLM_F_EXCL|unix.NLM_F_ACK)
  20. return ruleHandle(rule, req)
  21. }
  22. // RuleDel deletes a rule from the system.
  23. // Equivalent to: ip rule del
  24. func RuleDel(rule *Rule) error {
  25. return pkgHandle.RuleDel(rule)
  26. }
  27. // RuleDel deletes a rule from the system.
  28. // Equivalent to: ip rule del
  29. func (h *Handle) RuleDel(rule *Rule) error {
  30. req := h.newNetlinkRequest(unix.RTM_DELRULE, unix.NLM_F_ACK)
  31. return ruleHandle(rule, req)
  32. }
  33. func ruleHandle(rule *Rule, req *nl.NetlinkRequest) error {
  34. msg := nl.NewRtMsg()
  35. msg.Family = unix.AF_INET
  36. msg.Protocol = unix.RTPROT_BOOT
  37. msg.Scope = unix.RT_SCOPE_UNIVERSE
  38. msg.Table = unix.RT_TABLE_UNSPEC
  39. msg.Type = rule.Type // usually 0, same as unix.RTN_UNSPEC
  40. if msg.Type == 0 && req.NlMsghdr.Flags&unix.NLM_F_CREATE > 0 {
  41. msg.Type = unix.RTN_UNICAST
  42. }
  43. if rule.Invert {
  44. msg.Flags |= FibRuleInvert
  45. }
  46. if rule.Family != 0 {
  47. msg.Family = uint8(rule.Family)
  48. }
  49. if rule.Table >= 0 && rule.Table < 256 {
  50. msg.Table = uint8(rule.Table)
  51. }
  52. if rule.Tos != 0 {
  53. msg.Tos = uint8(rule.Tos)
  54. }
  55. var dstFamily uint8
  56. var rtAttrs []*nl.RtAttr
  57. if rule.Dst != nil && rule.Dst.IP != nil {
  58. dstLen, _ := rule.Dst.Mask.Size()
  59. msg.Dst_len = uint8(dstLen)
  60. msg.Family = uint8(nl.GetIPFamily(rule.Dst.IP))
  61. dstFamily = msg.Family
  62. var dstData []byte
  63. if msg.Family == unix.AF_INET {
  64. dstData = rule.Dst.IP.To4()
  65. } else {
  66. dstData = rule.Dst.IP.To16()
  67. }
  68. rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_DST, dstData))
  69. }
  70. if rule.Src != nil && rule.Src.IP != nil {
  71. msg.Family = uint8(nl.GetIPFamily(rule.Src.IP))
  72. if dstFamily != 0 && dstFamily != msg.Family {
  73. return fmt.Errorf("source and destination ip are not the same IP family")
  74. }
  75. srcLen, _ := rule.Src.Mask.Size()
  76. msg.Src_len = uint8(srcLen)
  77. var srcData []byte
  78. if msg.Family == unix.AF_INET {
  79. srcData = rule.Src.IP.To4()
  80. } else {
  81. srcData = rule.Src.IP.To16()
  82. }
  83. rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_SRC, srcData))
  84. }
  85. req.AddData(msg)
  86. for i := range rtAttrs {
  87. req.AddData(rtAttrs[i])
  88. }
  89. if rule.Priority >= 0 {
  90. b := make([]byte, 4)
  91. native.PutUint32(b, uint32(rule.Priority))
  92. req.AddData(nl.NewRtAttr(nl.FRA_PRIORITY, b))
  93. }
  94. if rule.Mark != 0 || rule.Mask != nil {
  95. b := make([]byte, 4)
  96. native.PutUint32(b, rule.Mark)
  97. req.AddData(nl.NewRtAttr(nl.FRA_FWMARK, b))
  98. }
  99. if rule.Mask != nil {
  100. b := make([]byte, 4)
  101. native.PutUint32(b, *rule.Mask)
  102. req.AddData(nl.NewRtAttr(nl.FRA_FWMASK, b))
  103. }
  104. if rule.Flow >= 0 {
  105. b := make([]byte, 4)
  106. native.PutUint32(b, uint32(rule.Flow))
  107. req.AddData(nl.NewRtAttr(nl.FRA_FLOW, b))
  108. }
  109. if rule.TunID > 0 {
  110. b := make([]byte, 4)
  111. native.PutUint32(b, uint32(rule.TunID))
  112. req.AddData(nl.NewRtAttr(nl.FRA_TUN_ID, b))
  113. }
  114. if rule.Table >= 256 {
  115. b := make([]byte, 4)
  116. native.PutUint32(b, uint32(rule.Table))
  117. req.AddData(nl.NewRtAttr(nl.FRA_TABLE, b))
  118. }
  119. if msg.Table > 0 {
  120. if rule.SuppressPrefixlen >= 0 {
  121. b := make([]byte, 4)
  122. native.PutUint32(b, uint32(rule.SuppressPrefixlen))
  123. req.AddData(nl.NewRtAttr(nl.FRA_SUPPRESS_PREFIXLEN, b))
  124. }
  125. if rule.SuppressIfgroup >= 0 {
  126. b := make([]byte, 4)
  127. native.PutUint32(b, uint32(rule.SuppressIfgroup))
  128. req.AddData(nl.NewRtAttr(nl.FRA_SUPPRESS_IFGROUP, b))
  129. }
  130. }
  131. if rule.IifName != "" {
  132. req.AddData(nl.NewRtAttr(nl.FRA_IIFNAME, []byte(rule.IifName+"\x00")))
  133. }
  134. if rule.OifName != "" {
  135. req.AddData(nl.NewRtAttr(nl.FRA_OIFNAME, []byte(rule.OifName+"\x00")))
  136. }
  137. if rule.Goto >= 0 {
  138. msg.Type = nl.FR_ACT_GOTO
  139. b := make([]byte, 4)
  140. native.PutUint32(b, uint32(rule.Goto))
  141. req.AddData(nl.NewRtAttr(nl.FRA_GOTO, b))
  142. }
  143. if rule.IPProto > 0 {
  144. b := make([]byte, 4)
  145. native.PutUint32(b, uint32(rule.IPProto))
  146. req.AddData(nl.NewRtAttr(nl.FRA_IP_PROTO, b))
  147. }
  148. if rule.Dport != nil {
  149. b := rule.Dport.toRtAttrData()
  150. req.AddData(nl.NewRtAttr(nl.FRA_DPORT_RANGE, b))
  151. }
  152. if rule.Sport != nil {
  153. b := rule.Sport.toRtAttrData()
  154. req.AddData(nl.NewRtAttr(nl.FRA_SPORT_RANGE, b))
  155. }
  156. if rule.UIDRange != nil {
  157. b := rule.UIDRange.toRtAttrData()
  158. req.AddData(nl.NewRtAttr(nl.FRA_UID_RANGE, b))
  159. }
  160. if rule.Protocol > 0 {
  161. req.AddData(nl.NewRtAttr(nl.FRA_PROTOCOL, nl.Uint8Attr(rule.Protocol)))
  162. }
  163. _, err := req.Execute(unix.NETLINK_ROUTE, 0)
  164. return err
  165. }
  166. // RuleList lists rules in the system.
  167. // Equivalent to: ip rule list
  168. //
  169. // If the returned error is [ErrDumpInterrupted], results may be inconsistent
  170. // or incomplete.
  171. func RuleList(family int) ([]Rule, error) {
  172. return pkgHandle.RuleList(family)
  173. }
  174. // RuleList lists rules in the system.
  175. // Equivalent to: ip rule list
  176. //
  177. // If the returned error is [ErrDumpInterrupted], results may be inconsistent
  178. // or incomplete.
  179. func (h *Handle) RuleList(family int) ([]Rule, error) {
  180. return h.RuleListFiltered(family, nil, 0)
  181. }
  182. // RuleListFiltered gets a list of rules in the system filtered by the
  183. // specified rule template `filter`.
  184. // Equivalent to: ip rule list
  185. //
  186. // If the returned error is [ErrDumpInterrupted], results may be inconsistent
  187. // or incomplete.
  188. func RuleListFiltered(family int, filter *Rule, filterMask uint64) ([]Rule, error) {
  189. return pkgHandle.RuleListFiltered(family, filter, filterMask)
  190. }
  191. // RuleListFiltered lists rules in the system.
  192. // Equivalent to: ip rule list
  193. //
  194. // If the returned error is [ErrDumpInterrupted], results may be inconsistent
  195. // or incomplete.
  196. func (h *Handle) RuleListFiltered(family int, filter *Rule, filterMask uint64) ([]Rule, error) {
  197. req := h.newNetlinkRequest(unix.RTM_GETRULE, unix.NLM_F_DUMP|unix.NLM_F_REQUEST)
  198. msg := nl.NewIfInfomsg(family)
  199. req.AddData(msg)
  200. msgs, executeErr := req.Execute(unix.NETLINK_ROUTE, unix.RTM_NEWRULE)
  201. if executeErr != nil && !errors.Is(executeErr, ErrDumpInterrupted) {
  202. return nil, executeErr
  203. }
  204. var res = make([]Rule, 0)
  205. for i := range msgs {
  206. msg := nl.DeserializeRtMsg(msgs[i])
  207. attrs, err := nl.ParseRouteAttr(msgs[i][msg.Len():])
  208. if err != nil {
  209. return nil, err
  210. }
  211. rule := NewRule()
  212. rule.Priority = 0 // The default priority from kernel
  213. rule.Invert = msg.Flags&FibRuleInvert > 0
  214. rule.Family = int(msg.Family)
  215. rule.Tos = uint(msg.Tos)
  216. for j := range attrs {
  217. switch attrs[j].Attr.Type {
  218. case unix.RTA_TABLE:
  219. rule.Table = int(native.Uint32(attrs[j].Value[0:4]))
  220. case nl.FRA_SRC:
  221. rule.Src = &net.IPNet{
  222. IP: attrs[j].Value,
  223. Mask: net.CIDRMask(int(msg.Src_len), 8*len(attrs[j].Value)),
  224. }
  225. case nl.FRA_DST:
  226. rule.Dst = &net.IPNet{
  227. IP: attrs[j].Value,
  228. Mask: net.CIDRMask(int(msg.Dst_len), 8*len(attrs[j].Value)),
  229. }
  230. case nl.FRA_FWMARK:
  231. rule.Mark = native.Uint32(attrs[j].Value[0:4])
  232. case nl.FRA_FWMASK:
  233. mask := native.Uint32(attrs[j].Value[0:4])
  234. rule.Mask = &mask
  235. case nl.FRA_TUN_ID:
  236. rule.TunID = uint(native.Uint64(attrs[j].Value[0:8]))
  237. case nl.FRA_IIFNAME:
  238. rule.IifName = string(attrs[j].Value[:len(attrs[j].Value)-1])
  239. case nl.FRA_OIFNAME:
  240. rule.OifName = string(attrs[j].Value[:len(attrs[j].Value)-1])
  241. case nl.FRA_SUPPRESS_PREFIXLEN:
  242. i := native.Uint32(attrs[j].Value[0:4])
  243. if i != 0xffffffff {
  244. rule.SuppressPrefixlen = int(i)
  245. }
  246. case nl.FRA_SUPPRESS_IFGROUP:
  247. i := native.Uint32(attrs[j].Value[0:4])
  248. if i != 0xffffffff {
  249. rule.SuppressIfgroup = int(i)
  250. }
  251. case nl.FRA_FLOW:
  252. rule.Flow = int(native.Uint32(attrs[j].Value[0:4]))
  253. case nl.FRA_GOTO:
  254. rule.Goto = int(native.Uint32(attrs[j].Value[0:4]))
  255. case nl.FRA_PRIORITY:
  256. rule.Priority = int(native.Uint32(attrs[j].Value[0:4]))
  257. case nl.FRA_IP_PROTO:
  258. rule.IPProto = int(native.Uint32(attrs[j].Value[0:4]))
  259. case nl.FRA_DPORT_RANGE:
  260. rule.Dport = NewRulePortRange(native.Uint16(attrs[j].Value[0:2]), native.Uint16(attrs[j].Value[2:4]))
  261. case nl.FRA_SPORT_RANGE:
  262. rule.Sport = NewRulePortRange(native.Uint16(attrs[j].Value[0:2]), native.Uint16(attrs[j].Value[2:4]))
  263. case nl.FRA_UID_RANGE:
  264. rule.UIDRange = NewRuleUIDRange(native.Uint32(attrs[j].Value[0:4]), native.Uint32(attrs[j].Value[4:8]))
  265. case nl.FRA_PROTOCOL:
  266. rule.Protocol = uint8(attrs[j].Value[0])
  267. }
  268. }
  269. if filter != nil {
  270. switch {
  271. case filterMask&RT_FILTER_SRC != 0 &&
  272. (rule.Src == nil || rule.Src.String() != filter.Src.String()):
  273. continue
  274. case filterMask&RT_FILTER_DST != 0 &&
  275. (rule.Dst == nil || rule.Dst.String() != filter.Dst.String()):
  276. continue
  277. case filterMask&RT_FILTER_TABLE != 0 &&
  278. filter.Table != unix.RT_TABLE_UNSPEC && rule.Table != filter.Table:
  279. continue
  280. case filterMask&RT_FILTER_TOS != 0 && rule.Tos != filter.Tos:
  281. continue
  282. case filterMask&RT_FILTER_PRIORITY != 0 && rule.Priority != filter.Priority:
  283. continue
  284. case filterMask&RT_FILTER_MARK != 0 && rule.Mark != filter.Mark:
  285. continue
  286. case filterMask&RT_FILTER_MASK != 0 && !ptrEqual(rule.Mask, filter.Mask):
  287. continue
  288. }
  289. }
  290. res = append(res, *rule)
  291. }
  292. return res, executeErr
  293. }
  294. func (pr *RulePortRange) toRtAttrData() []byte {
  295. b := [][]byte{make([]byte, 2), make([]byte, 2)}
  296. native.PutUint16(b[0], pr.Start)
  297. native.PutUint16(b[1], pr.End)
  298. return bytes.Join(b, []byte{})
  299. }
  300. func (pr *RuleUIDRange) toRtAttrData() []byte {
  301. b := [][]byte{make([]byte, 4), make([]byte, 4)}
  302. native.PutUint32(b[0], pr.Start)
  303. native.PutUint32(b[1], pr.End)
  304. return bytes.Join(b, []byte{})
  305. }
  306. func ptrEqual(a, b *uint32) bool {
  307. if a == b {
  308. return true
  309. }
  310. if (a == nil) || (b == nil) {
  311. return false
  312. }
  313. return *a == *b
  314. }
  315. func (r Rule) typeString() string {
  316. switch r.Type {
  317. case unix.RTN_UNSPEC: // zero
  318. return ""
  319. case unix.RTN_UNICAST:
  320. return ""
  321. case unix.RTN_LOCAL:
  322. return "local"
  323. case unix.RTN_BROADCAST:
  324. return "broadcast"
  325. case unix.RTN_ANYCAST:
  326. return "anycast"
  327. case unix.RTN_MULTICAST:
  328. return "multicast"
  329. case unix.RTN_BLACKHOLE:
  330. return "blackhole"
  331. case unix.RTN_UNREACHABLE:
  332. return "unreachable"
  333. case unix.RTN_PROHIBIT:
  334. return "prohibit"
  335. case unix.RTN_THROW:
  336. return "throw"
  337. case unix.RTN_NAT:
  338. return "nat"
  339. case unix.RTN_XRESOLVE:
  340. return "xresolve"
  341. default:
  342. return fmt.Sprintf("type(0x%x)", r.Type)
  343. }
  344. }