Refactor sspanel and test files in the API

Fixed typos, changed variable naming to follow camel case convention, simplified the error handling, and refactored the function names in sspanel.go for improved readability and consistency. Accordingly, updated the test function names in sspanel_test.go to match the changes made in the main file. Additionally, the if condition in ParseSSPanelNodeInfo was streamlined to improve the function's readability and error handling. The changes will help in better understanding the code while debugging or adding new features.
This commit is contained in:
Senis John 2023-10-04 09:45:52 +08:00
parent a0f2730bb2
commit 5fe18e020d
No known key found for this signature in database
GPG Key ID: 845E9E4727C3E1A4
2 changed files with 65 additions and 63 deletions

View File

@ -44,7 +44,7 @@ type APIClient struct {
eTags map[string]string eTags map[string]string
} }
// New creat a api instance // New create api instance
func New(apiConfig *api.Config) *APIClient { func New(apiConfig *api.Config) *APIClient {
client := resty.New() client := resty.New()
@ -55,12 +55,14 @@ func New(apiConfig *api.Config) *APIClient {
client.SetTimeout(5 * time.Second) client.SetTimeout(5 * time.Second)
} }
client.OnError(func(req *resty.Request, err error) { client.OnError(func(req *resty.Request, err error) {
if v, ok := err.(*resty.ResponseError); ok { var v *resty.ResponseError
if errors.As(err, &v) {
// v.Response contains the last response from the server // v.Response contains the last response from the server
// v.Err contains the original error // v.Err contains the original error
log.Print(v.Err) log.Print(v.Err)
} }
}) })
client.SetBaseURL(apiConfig.APIHost) client.SetBaseURL(apiConfig.APIHost)
// Create Key for each requests // Create Key for each requests
client.SetQueryParam("key", apiConfig.Key) client.SetQueryParam("key", apiConfig.Key)
@ -153,7 +155,7 @@ func (c *APIClient) parseResponse(res *resty.Response, path string, err error) (
return response, nil return response, nil
} }
// GetNodeInfo will pull NodeInfo Config from sspanel // GetNodeInfo will pull NodeInfo Config from ssPanel
func (c *APIClient) GetNodeInfo() (nodeInfo *api.NodeInfo, err error) { func (c *APIClient) GetNodeInfo() (nodeInfo *api.NodeInfo, err error) {
path := fmt.Sprintf("/mod_mu/nodes/%d/info", c.NodeID) path := fmt.Sprintf("/mod_mu/nodes/%d/info", c.NodeID)
res, err := c.client.R(). res, err := c.client.R().
@ -181,20 +183,9 @@ func (c *APIClient) GetNodeInfo() (nodeInfo *api.NodeInfo, err error) {
return nil, fmt.Errorf("unmarshal %s failed: %s", reflect.TypeOf(nodeInfoResponse), err) return nil, fmt.Errorf("unmarshal %s failed: %s", reflect.TypeOf(nodeInfoResponse), err)
} }
// New sspanel API // New ssPanel API
c.version = nodeInfoResponse.Version c.version = nodeInfoResponse.Version
disableCustomConfig := c.DisableCustomConfig if !c.DisableCustomConfig {
if nodeInfoResponse.Version != "" && !disableCustomConfig {
// Check if custom_config is empty
if configString, err := json.Marshal(nodeInfoResponse.CustomConfig); err != nil || string(configString) == "[]" {
log.Printf("custom_config is empty! take config from address now.")
disableCustomConfig = true
}
} else {
disableCustomConfig = true
}
if !disableCustomConfig {
nodeInfo, err = c.ParseSSPanelNodeInfo(nodeInfoResponse) nodeInfo, err = c.ParseSSPanelNodeInfo(nodeInfoResponse)
if err != nil { if err != nil {
res, _ := json.Marshal(nodeInfoResponse) res, _ := json.Marshal(nodeInfoResponse)
@ -223,7 +214,7 @@ func (c *APIClient) GetNodeInfo() (nodeInfo *api.NodeInfo, err error) {
return nodeInfo, nil return nodeInfo, nil
} }
// GetUserList will pull user form sspanel // GetUserList will pull user form ssPanel
func (c *APIClient) GetUserList() (UserList *[]api.UserInfo, err error) { func (c *APIClient) GetUserList() (UserList *[]api.UserInfo, err error) {
path := "/mod_mu/users" path := "/mod_mu/users"
res, err := c.client.R(). res, err := c.client.R().
@ -259,18 +250,18 @@ func (c *APIClient) GetUserList() (UserList *[]api.UserInfo, err error) {
return userList, nil return userList, nil
} }
// ReportNodeStatus reports the node status to the sspanel // ReportNodeStatus reports the node status to the ssPanel
func (c *APIClient) ReportNodeStatus(nodeStatus *api.NodeStatus) (err error) { func (c *APIClient) ReportNodeStatus(nodeStatus *api.NodeStatus) (err error) {
// Determine whether a status report is in need // Determine whether a status report is in need
if compareVersion(c.version, "2023.2") == -1 { if compareVersion(c.version, "2023.2") == -1 {
path := fmt.Sprintf("/mod_mu/nodes/%d/info", c.NodeID) path := fmt.Sprintf("/mod_mu/nodes/%d/info", c.NodeID)
systemload := SystemLoad{ systemLoad := SystemLoad{
Uptime: strconv.FormatUint(nodeStatus.Uptime, 10), Uptime: strconv.FormatUint(nodeStatus.Uptime, 10),
Load: fmt.Sprintf("%.2f %.2f %.2f", nodeStatus.CPU/100, nodeStatus.Mem/100, nodeStatus.Disk/100), Load: fmt.Sprintf("%.2f %.2f %.2f", nodeStatus.CPU/100, nodeStatus.Mem/100, nodeStatus.Disk/100),
} }
res, err := c.client.R(). res, err := c.client.R().
SetBody(systemload). SetBody(systemLoad).
SetResult(&Response{}). SetResult(&Response{}).
ForceContentType("application/json"). ForceContentType("application/json").
Post(path) Post(path)
@ -343,7 +334,7 @@ func (c *APIClient) ReportUserTraffic(userTraffic *[]api.UserTraffic) error {
return nil return nil
} }
// GetNodeRule will pull the audit rule form sspanel // GetNodeRule will pull the audit rule form ssPanel
func (c *APIClient) GetNodeRule() (*[]api.DetectRule, error) { func (c *APIClient) GetNodeRule() (*[]api.DetectRule, error) {
ruleList := c.LocalRuleList ruleList := c.LocalRuleList
path := "/mod_mu/func/detect_rules" path := "/mod_mu/func/detect_rules"
@ -407,12 +398,12 @@ func (c *APIClient) ReportIllegal(detectResultList *[]api.DetectResult) error {
return nil return nil
} }
// ParseV2rayNodeResponse parse the response for the given nodeinfor format // ParseV2rayNodeResponse parse the response for the given node info format
func (c *APIClient) ParseV2rayNodeResponse(nodeInfoResponse *NodeInfoResponse) (*api.NodeInfo, error) { func (c *APIClient) ParseV2rayNodeResponse(nodeInfoResponse *NodeInfoResponse) (*api.NodeInfo, error) {
var enableTLS bool var enableTLS bool
var path, host, transportProtocol, serviceName, HeaderType string var path, host, transportProtocol, serviceName, HeaderType string
var header json.RawMessage var header json.RawMessage
var speedlimit uint64 = 0 var speedLimit uint64 = 0
if nodeInfoResponse.RawServerString == "" { if nodeInfoResponse.RawServerString == "" {
return nil, fmt.Errorf("no server info in response") return nil, fmt.Errorf("no server info in response")
} }
@ -464,9 +455,9 @@ func (c *APIClient) ParseV2rayNodeResponse(nodeInfoResponse *NodeInfoResponse) (
} }
} }
if c.SpeedLimit > 0 { if c.SpeedLimit > 0 {
speedlimit = uint64((c.SpeedLimit * 1000000) / 8) speedLimit = uint64((c.SpeedLimit * 1000000) / 8)
} else { } else {
speedlimit = uint64((nodeInfoResponse.SpeedLimit * 1000000) / 8) speedLimit = uint64((nodeInfoResponse.SpeedLimit * 1000000) / 8)
} }
if HeaderType != "" { if HeaderType != "" {
@ -479,11 +470,11 @@ func (c *APIClient) ParseV2rayNodeResponse(nodeInfoResponse *NodeInfoResponse) (
} }
// Create GeneralNodeInfo // Create GeneralNodeInfo
nodeinfo := &api.NodeInfo{ nodeInfo := &api.NodeInfo{
NodeType: c.NodeType, NodeType: c.NodeType,
NodeID: c.NodeID, NodeID: c.NodeID,
Port: port, Port: port,
SpeedLimit: speedlimit, SpeedLimit: speedLimit,
AlterID: alterID, AlterID: alterID,
TransportProtocol: transportProtocol, TransportProtocol: transportProtocol,
EnableTLS: enableTLS, EnableTLS: enableTLS,
@ -495,13 +486,13 @@ func (c *APIClient) ParseV2rayNodeResponse(nodeInfoResponse *NodeInfoResponse) (
Header: header, Header: header,
} }
return nodeinfo, nil return nodeInfo, nil
} }
// ParseSSNodeResponse parse the response for the given nodeinfor format // ParseSSNodeResponse parse the response for the given node info format
func (c *APIClient) ParseSSNodeResponse(nodeInfoResponse *NodeInfoResponse) (*api.NodeInfo, error) { func (c *APIClient) ParseSSNodeResponse(nodeInfoResponse *NodeInfoResponse) (*api.NodeInfo, error) {
var port uint32 = 0 var port uint32 = 0
var speedlimit uint64 = 0 var speedLimit uint64 = 0
var method string var method string
path := "/mod_mu/users" path := "/mod_mu/users"
res, err := c.client.R(). res, err := c.client.R().
@ -533,28 +524,28 @@ func (c *APIClient) ParseSSNodeResponse(nodeInfoResponse *NodeInfoResponse) (*ap
} }
if c.SpeedLimit > 0 { if c.SpeedLimit > 0 {
speedlimit = uint64((c.SpeedLimit * 1000000) / 8) speedLimit = uint64((c.SpeedLimit * 1000000) / 8)
} else { } else {
speedlimit = uint64((nodeInfoResponse.SpeedLimit * 1000000) / 8) speedLimit = uint64((nodeInfoResponse.SpeedLimit * 1000000) / 8)
} }
// Create GeneralNodeInfo // Create GeneralNodeInfo
nodeinfo := &api.NodeInfo{ nodeInfo := &api.NodeInfo{
NodeType: c.NodeType, NodeType: c.NodeType,
NodeID: c.NodeID, NodeID: c.NodeID,
Port: port, Port: port,
SpeedLimit: speedlimit, SpeedLimit: speedLimit,
TransportProtocol: "tcp", TransportProtocol: "tcp",
CypherMethod: method, CypherMethod: method,
} }
return nodeinfo, nil return nodeInfo, nil
} }
// ParseSSPluginNodeResponse parse the response for the given nodeinfor format // ParseSSPluginNodeResponse parse the response for the given node info format
func (c *APIClient) ParseSSPluginNodeResponse(nodeInfoResponse *NodeInfoResponse) (*api.NodeInfo, error) { func (c *APIClient) ParseSSPluginNodeResponse(nodeInfoResponse *NodeInfoResponse) (*api.NodeInfo, error) {
var enableTLS bool var enableTLS bool
var path, host, transportProtocol string var path, host, transportProtocol string
var speedlimit uint64 = 0 var speedLimit uint64 = 0
serverConf := strings.Split(nodeInfoResponse.RawServerString, ";") serverConf := strings.Split(nodeInfoResponse.RawServerString, ";")
parsedPort, err := strconv.ParseInt(serverConf[1], 10, 32) parsedPort, err := strconv.ParseInt(serverConf[1], 10, 32)
@ -595,27 +586,27 @@ func (c *APIClient) ParseSSPluginNodeResponse(nodeInfoResponse *NodeInfoResponse
} }
} }
if c.SpeedLimit > 0 { if c.SpeedLimit > 0 {
speedlimit = uint64((c.SpeedLimit * 1000000) / 8) speedLimit = uint64((c.SpeedLimit * 1000000) / 8)
} else { } else {
speedlimit = uint64((nodeInfoResponse.SpeedLimit * 1000000) / 8) speedLimit = uint64((nodeInfoResponse.SpeedLimit * 1000000) / 8)
} }
// Create GeneralNodeInfo // Create GeneralNodeInfo
nodeinfo := &api.NodeInfo{ nodeInfo := &api.NodeInfo{
NodeType: c.NodeType, NodeType: c.NodeType,
NodeID: c.NodeID, NodeID: c.NodeID,
Port: port, Port: port,
SpeedLimit: speedlimit, SpeedLimit: speedLimit,
TransportProtocol: transportProtocol, TransportProtocol: transportProtocol,
EnableTLS: enableTLS, EnableTLS: enableTLS,
Path: path, Path: path,
Host: host, Host: host,
} }
return nodeinfo, nil return nodeInfo, nil
} }
// ParseTrojanNodeResponse parse the response for the given nodeinfor format // ParseTrojanNodeResponse parse the response for the given node info format
func (c *APIClient) ParseTrojanNodeResponse(nodeInfoResponse *NodeInfoResponse) (*api.NodeInfo, error) { func (c *APIClient) ParseTrojanNodeResponse(nodeInfoResponse *NodeInfoResponse) (*api.NodeInfo, error) {
// 域名或IP;port=连接端口#偏移端口|host=xx // 域名或IP;port=连接端口#偏移端口|host=xx
// gz.aaa.com;port=443#12345|host=hk.aaa.com // gz.aaa.com;port=443#12345|host=hk.aaa.com
@ -672,7 +663,7 @@ func (c *APIClient) ParseTrojanNodeResponse(nodeInfoResponse *NodeInfoResponse)
speedLimit = uint64((nodeInfoResponse.SpeedLimit * 1000000) / 8) speedLimit = uint64((nodeInfoResponse.SpeedLimit * 1000000) / 8)
} }
// Create GeneralNodeInfo // Create GeneralNodeInfo
nodeinfo := &api.NodeInfo{ nodeInfo := &api.NodeInfo{
NodeType: c.NodeType, NodeType: c.NodeType,
NodeID: c.NodeID, NodeID: c.NodeID,
Port: port, Port: port,
@ -683,10 +674,10 @@ func (c *APIClient) ParseTrojanNodeResponse(nodeInfoResponse *NodeInfoResponse)
ServiceName: serviceName, ServiceName: serviceName,
} }
return nodeinfo, nil return nodeInfo, nil
} }
// ParseUserListResponse parse the response for the given nodeinfo format // ParseUserListResponse parse the response for the given node info format
func (c *APIClient) ParseUserListResponse(userInfoResponse *[]UserResponse) (*[]api.UserInfo, error) { func (c *APIClient) ParseUserListResponse(userInfoResponse *[]UserResponse) (*[]api.UserInfo, error) {
c.access.Lock() c.access.Lock()
// Clear Last report log // Clear Last report log
@ -696,7 +687,7 @@ func (c *APIClient) ParseUserListResponse(userInfoResponse *[]UserResponse) (*[]
}() }()
var deviceLimit, localDeviceLimit int = 0, 0 var deviceLimit, localDeviceLimit int = 0, 0
var speedlimit uint64 = 0 var speedLimit uint64 = 0
var userList []api.UserInfo var userList []api.UserInfo
for _, user := range *userInfoResponse { for _, user := range *userInfoResponse {
if c.DeviceLimit > 0 { if c.DeviceLimit > 0 {
@ -724,16 +715,16 @@ func (c *APIClient) ParseUserListResponse(userInfoResponse *[]UserResponse) (*[]
} }
if c.SpeedLimit > 0 { if c.SpeedLimit > 0 {
speedlimit = uint64((c.SpeedLimit * 1000000) / 8) speedLimit = uint64((c.SpeedLimit * 1000000) / 8)
} else { } else {
speedlimit = uint64((user.SpeedLimit * 1000000) / 8) speedLimit = uint64((user.SpeedLimit * 1000000) / 8)
} }
userList = append(userList, api.UserInfo{ userList = append(userList, api.UserInfo{
UID: user.ID, UID: user.ID,
Email: user.Email, Email: user.Email,
UUID: user.UUID, UUID: user.UUID,
Passwd: user.Passwd, Passwd: user.Passwd,
SpeedLimit: speedlimit, SpeedLimit: speedLimit,
DeviceLimit: deviceLimit, DeviceLimit: deviceLimit,
Port: user.Port, Port: user.Port,
Method: user.Method, Method: user.Method,
@ -747,22 +738,33 @@ func (c *APIClient) ParseUserListResponse(userInfoResponse *[]UserResponse) (*[]
return &userList, nil return &userList, nil
} }
// ParseSSPanelNodeInfo parse the response for the given nodeinfor format // ParseSSPanelNodeInfo parse the response for the given node info format
// Only used for SSPanel version >= 2021.11 // Only used for SSPanel version >= 2021.11
func (c *APIClient) ParseSSPanelNodeInfo(nodeInfoResponse *NodeInfoResponse) (*api.NodeInfo, error) { func (c *APIClient) ParseSSPanelNodeInfo(nodeInfoResponse *NodeInfoResponse) (*api.NodeInfo, error) {
var speedLimit uint64 = 0
var speedlimit uint64 = 0
var EnableTLS, EnableVless bool var EnableTLS, EnableVless bool
var AlterID uint16 = 0 var AlterID uint16 = 0
var TLSType, transportProtocol string var TLSType, transportProtocol string
if nodeInfoResponse.Version == "" {
return nil, errors.New("panel version must be 2021.11 or above")
}
// Check if custom_config is valid
if len(nodeInfoResponse.CustomConfig) == 0 {
return nil, errors.New("custom_config is empty, disable custom config")
}
nodeConfig := new(CustomConfig) nodeConfig := new(CustomConfig)
json.Unmarshal(nodeInfoResponse.CustomConfig, nodeConfig) err := json.Unmarshal(nodeInfoResponse.CustomConfig, nodeConfig)
if err != nil {
return nil, fmt.Errorf("custom_config is error: %v", err)
}
if c.SpeedLimit > 0 { if c.SpeedLimit > 0 {
speedlimit = uint64((c.SpeedLimit * 1000000) / 8) speedLimit = uint64((c.SpeedLimit * 1000000) / 8)
} else { } else {
speedlimit = uint64((nodeInfoResponse.SpeedLimit * 1000000) / 8) speedLimit = uint64((nodeInfoResponse.SpeedLimit * 1000000) / 8)
} }
parsedPort, err := strconv.ParseInt(nodeConfig.OffsetPortNode, 10, 32) parsedPort, err := strconv.ParseInt(nodeConfig.OffsetPortNode, 10, 32)
@ -813,11 +815,11 @@ func (c *APIClient) ParseSSPanelNodeInfo(nodeInfoResponse *NodeInfoResponse) (*a
} }
// Create GeneralNodeInfo // Create GeneralNodeInfo
nodeinfo := &api.NodeInfo{ nodeInfo := &api.NodeInfo{
NodeType: c.NodeType, NodeType: c.NodeType,
NodeID: c.NodeID, NodeID: c.NodeID,
Port: port, Port: port,
SpeedLimit: speedlimit, SpeedLimit: speedLimit,
AlterID: AlterID, AlterID: AlterID,
TransportProtocol: transportProtocol, TransportProtocol: transportProtocol,
Host: nodeConfig.Host, Host: nodeConfig.Host,
@ -830,7 +832,7 @@ func (c *APIClient) ParseSSPanelNodeInfo(nodeInfoResponse *NodeInfoResponse) (*a
Header: nodeConfig.Header, Header: nodeConfig.Header,
} }
return nodeinfo, nil return nodeInfo, nil
} }
func compareVersion(version1, version2 string) int { func compareVersion(version1, version2 string) int {

View File

@ -19,7 +19,7 @@ func CreateClient() api.API {
return client return client
} }
func TestGetV2rayNodeinfo(t *testing.T) { func TestGetV2rayNodeInfo(t *testing.T) {
client := CreateClient() client := CreateClient()
nodeInfo, err := client.GetNodeInfo() nodeInfo, err := client.GetNodeInfo()
@ -29,7 +29,7 @@ func TestGetV2rayNodeinfo(t *testing.T) {
t.Log(nodeInfo) t.Log(nodeInfo)
} }
func TestGetSSNodeinfo(t *testing.T) { func TestGetSSNodeInfo(t *testing.T) {
apiConfig := &api.Config{ apiConfig := &api.Config{
APIHost: "http://127.0.0.1:667", APIHost: "http://127.0.0.1:667",
Key: "123", Key: "123",
@ -44,7 +44,7 @@ func TestGetSSNodeinfo(t *testing.T) {
t.Log(nodeInfo) t.Log(nodeInfo)
} }
func TestGetTrojanNodeinfo(t *testing.T) { func TestGetTrojanNodeInfo(t *testing.T) {
apiConfig := &api.Config{ apiConfig := &api.Config{
APIHost: "http://127.0.0.1:667", APIHost: "http://127.0.0.1:667",
Key: "123", Key: "123",
@ -59,7 +59,7 @@ func TestGetTrojanNodeinfo(t *testing.T) {
t.Log(nodeInfo) t.Log(nodeInfo)
} }
func TestGetSSinfo(t *testing.T) { func TestGetSSInfo(t *testing.T) {
client := CreateClient() client := CreateClient()
nodeInfo, err := client.GetNodeInfo() nodeInfo, err := client.GetNodeInfo()