From 5fe18e020d8dec508743e87797bd4514dc2d5fd3 Mon Sep 17 00:00:00 2001 From: Senis John Date: Wed, 4 Oct 2023 09:45:52 +0800 Subject: [PATCH] Refactor sspanel and test files in the API Fixed typos, changed variable naming to follow camel case convention, simplified the error handling, and refactored the function names in sspanel.go for improved readability and consistency. Accordingly, updated the test function names in sspanel_test.go to match the changes made in the main file. Additionally, the if condition in ParseSSPanelNodeInfo was streamlined to improve the function's readability and error handling. The changes will help in better understanding the code while debugging or adding new features. --- api/sspanel/sspanel.go | 120 ++++++++++++++++++------------------ api/sspanel/sspanel_test.go | 8 +-- 2 files changed, 65 insertions(+), 63 deletions(-) diff --git a/api/sspanel/sspanel.go b/api/sspanel/sspanel.go index bbcfdf7..bd7fe3d 100644 --- a/api/sspanel/sspanel.go +++ b/api/sspanel/sspanel.go @@ -44,7 +44,7 @@ type APIClient struct { eTags map[string]string } -// New creat a api instance +// New create api instance func New(apiConfig *api.Config) *APIClient { client := resty.New() @@ -55,12 +55,14 @@ func New(apiConfig *api.Config) *APIClient { client.SetTimeout(5 * time.Second) } client.OnError(func(req *resty.Request, err error) { - if v, ok := err.(*resty.ResponseError); ok { + var v *resty.ResponseError + if errors.As(err, &v) { // v.Response contains the last response from the server // v.Err contains the original error log.Print(v.Err) } }) + client.SetBaseURL(apiConfig.APIHost) // Create Key for each requests client.SetQueryParam("key", apiConfig.Key) @@ -153,7 +155,7 @@ func (c *APIClient) parseResponse(res *resty.Response, path string, err error) ( return response, nil } -// GetNodeInfo will pull NodeInfo Config from sspanel +// GetNodeInfo will pull NodeInfo Config from ssPanel func (c *APIClient) GetNodeInfo() (nodeInfo *api.NodeInfo, err error) { path := fmt.Sprintf("/mod_mu/nodes/%d/info", c.NodeID) res, err := c.client.R(). @@ -181,20 +183,9 @@ func (c *APIClient) GetNodeInfo() (nodeInfo *api.NodeInfo, err error) { return nil, fmt.Errorf("unmarshal %s failed: %s", reflect.TypeOf(nodeInfoResponse), err) } - // New sspanel API + // New ssPanel API c.version = nodeInfoResponse.Version - disableCustomConfig := c.DisableCustomConfig - if nodeInfoResponse.Version != "" && !disableCustomConfig { - // Check if custom_config is empty - if configString, err := json.Marshal(nodeInfoResponse.CustomConfig); err != nil || string(configString) == "[]" { - log.Printf("custom_config is empty! take config from address now.") - disableCustomConfig = true - } - } else { - disableCustomConfig = true - } - - if !disableCustomConfig { + if !c.DisableCustomConfig { nodeInfo, err = c.ParseSSPanelNodeInfo(nodeInfoResponse) if err != nil { res, _ := json.Marshal(nodeInfoResponse) @@ -223,7 +214,7 @@ func (c *APIClient) GetNodeInfo() (nodeInfo *api.NodeInfo, err error) { return nodeInfo, nil } -// GetUserList will pull user form sspanel +// GetUserList will pull user form ssPanel func (c *APIClient) GetUserList() (UserList *[]api.UserInfo, err error) { path := "/mod_mu/users" res, err := c.client.R(). @@ -259,18 +250,18 @@ func (c *APIClient) GetUserList() (UserList *[]api.UserInfo, err error) { return userList, nil } -// ReportNodeStatus reports the node status to the sspanel +// ReportNodeStatus reports the node status to the ssPanel func (c *APIClient) ReportNodeStatus(nodeStatus *api.NodeStatus) (err error) { // Determine whether a status report is in need if compareVersion(c.version, "2023.2") == -1 { path := fmt.Sprintf("/mod_mu/nodes/%d/info", c.NodeID) - systemload := SystemLoad{ + systemLoad := SystemLoad{ Uptime: strconv.FormatUint(nodeStatus.Uptime, 10), Load: fmt.Sprintf("%.2f %.2f %.2f", nodeStatus.CPU/100, nodeStatus.Mem/100, nodeStatus.Disk/100), } res, err := c.client.R(). - SetBody(systemload). + SetBody(systemLoad). SetResult(&Response{}). ForceContentType("application/json"). Post(path) @@ -343,7 +334,7 @@ func (c *APIClient) ReportUserTraffic(userTraffic *[]api.UserTraffic) error { return nil } -// GetNodeRule will pull the audit rule form sspanel +// GetNodeRule will pull the audit rule form ssPanel func (c *APIClient) GetNodeRule() (*[]api.DetectRule, error) { ruleList := c.LocalRuleList path := "/mod_mu/func/detect_rules" @@ -407,12 +398,12 @@ func (c *APIClient) ReportIllegal(detectResultList *[]api.DetectResult) error { return nil } -// ParseV2rayNodeResponse parse the response for the given nodeinfor format +// ParseV2rayNodeResponse parse the response for the given node info format func (c *APIClient) ParseV2rayNodeResponse(nodeInfoResponse *NodeInfoResponse) (*api.NodeInfo, error) { var enableTLS bool var path, host, transportProtocol, serviceName, HeaderType string var header json.RawMessage - var speedlimit uint64 = 0 + var speedLimit uint64 = 0 if nodeInfoResponse.RawServerString == "" { return nil, fmt.Errorf("no server info in response") } @@ -464,9 +455,9 @@ func (c *APIClient) ParseV2rayNodeResponse(nodeInfoResponse *NodeInfoResponse) ( } } if c.SpeedLimit > 0 { - speedlimit = uint64((c.SpeedLimit * 1000000) / 8) + speedLimit = uint64((c.SpeedLimit * 1000000) / 8) } else { - speedlimit = uint64((nodeInfoResponse.SpeedLimit * 1000000) / 8) + speedLimit = uint64((nodeInfoResponse.SpeedLimit * 1000000) / 8) } if HeaderType != "" { @@ -479,11 +470,11 @@ func (c *APIClient) ParseV2rayNodeResponse(nodeInfoResponse *NodeInfoResponse) ( } // Create GeneralNodeInfo - nodeinfo := &api.NodeInfo{ + nodeInfo := &api.NodeInfo{ NodeType: c.NodeType, NodeID: c.NodeID, Port: port, - SpeedLimit: speedlimit, + SpeedLimit: speedLimit, AlterID: alterID, TransportProtocol: transportProtocol, EnableTLS: enableTLS, @@ -495,13 +486,13 @@ func (c *APIClient) ParseV2rayNodeResponse(nodeInfoResponse *NodeInfoResponse) ( Header: header, } - return nodeinfo, nil + return nodeInfo, nil } -// ParseSSNodeResponse parse the response for the given nodeinfor format +// ParseSSNodeResponse parse the response for the given node info format func (c *APIClient) ParseSSNodeResponse(nodeInfoResponse *NodeInfoResponse) (*api.NodeInfo, error) { var port uint32 = 0 - var speedlimit uint64 = 0 + var speedLimit uint64 = 0 var method string path := "/mod_mu/users" res, err := c.client.R(). @@ -533,28 +524,28 @@ func (c *APIClient) ParseSSNodeResponse(nodeInfoResponse *NodeInfoResponse) (*ap } if c.SpeedLimit > 0 { - speedlimit = uint64((c.SpeedLimit * 1000000) / 8) + speedLimit = uint64((c.SpeedLimit * 1000000) / 8) } else { - speedlimit = uint64((nodeInfoResponse.SpeedLimit * 1000000) / 8) + speedLimit = uint64((nodeInfoResponse.SpeedLimit * 1000000) / 8) } // Create GeneralNodeInfo - nodeinfo := &api.NodeInfo{ + nodeInfo := &api.NodeInfo{ NodeType: c.NodeType, NodeID: c.NodeID, Port: port, - SpeedLimit: speedlimit, + SpeedLimit: speedLimit, TransportProtocol: "tcp", CypherMethod: method, } - return nodeinfo, nil + return nodeInfo, nil } -// ParseSSPluginNodeResponse parse the response for the given nodeinfor format +// ParseSSPluginNodeResponse parse the response for the given node info format func (c *APIClient) ParseSSPluginNodeResponse(nodeInfoResponse *NodeInfoResponse) (*api.NodeInfo, error) { var enableTLS bool var path, host, transportProtocol string - var speedlimit uint64 = 0 + var speedLimit uint64 = 0 serverConf := strings.Split(nodeInfoResponse.RawServerString, ";") parsedPort, err := strconv.ParseInt(serverConf[1], 10, 32) @@ -595,27 +586,27 @@ func (c *APIClient) ParseSSPluginNodeResponse(nodeInfoResponse *NodeInfoResponse } } if c.SpeedLimit > 0 { - speedlimit = uint64((c.SpeedLimit * 1000000) / 8) + speedLimit = uint64((c.SpeedLimit * 1000000) / 8) } else { - speedlimit = uint64((nodeInfoResponse.SpeedLimit * 1000000) / 8) + speedLimit = uint64((nodeInfoResponse.SpeedLimit * 1000000) / 8) } // Create GeneralNodeInfo - nodeinfo := &api.NodeInfo{ + nodeInfo := &api.NodeInfo{ NodeType: c.NodeType, NodeID: c.NodeID, Port: port, - SpeedLimit: speedlimit, + SpeedLimit: speedLimit, TransportProtocol: transportProtocol, EnableTLS: enableTLS, Path: path, Host: host, } - return nodeinfo, nil + return nodeInfo, nil } -// ParseTrojanNodeResponse parse the response for the given nodeinfor format +// ParseTrojanNodeResponse parse the response for the given node info format func (c *APIClient) ParseTrojanNodeResponse(nodeInfoResponse *NodeInfoResponse) (*api.NodeInfo, error) { // 域名或IP;port=连接端口#偏移端口|host=xx // gz.aaa.com;port=443#12345|host=hk.aaa.com @@ -672,7 +663,7 @@ func (c *APIClient) ParseTrojanNodeResponse(nodeInfoResponse *NodeInfoResponse) speedLimit = uint64((nodeInfoResponse.SpeedLimit * 1000000) / 8) } // Create GeneralNodeInfo - nodeinfo := &api.NodeInfo{ + nodeInfo := &api.NodeInfo{ NodeType: c.NodeType, NodeID: c.NodeID, Port: port, @@ -683,10 +674,10 @@ func (c *APIClient) ParseTrojanNodeResponse(nodeInfoResponse *NodeInfoResponse) ServiceName: serviceName, } - return nodeinfo, nil + return nodeInfo, nil } -// ParseUserListResponse parse the response for the given nodeinfo format +// ParseUserListResponse parse the response for the given node info format func (c *APIClient) ParseUserListResponse(userInfoResponse *[]UserResponse) (*[]api.UserInfo, error) { c.access.Lock() // Clear Last report log @@ -696,7 +687,7 @@ func (c *APIClient) ParseUserListResponse(userInfoResponse *[]UserResponse) (*[] }() var deviceLimit, localDeviceLimit int = 0, 0 - var speedlimit uint64 = 0 + var speedLimit uint64 = 0 var userList []api.UserInfo for _, user := range *userInfoResponse { if c.DeviceLimit > 0 { @@ -724,16 +715,16 @@ func (c *APIClient) ParseUserListResponse(userInfoResponse *[]UserResponse) (*[] } if c.SpeedLimit > 0 { - speedlimit = uint64((c.SpeedLimit * 1000000) / 8) + speedLimit = uint64((c.SpeedLimit * 1000000) / 8) } else { - speedlimit = uint64((user.SpeedLimit * 1000000) / 8) + speedLimit = uint64((user.SpeedLimit * 1000000) / 8) } userList = append(userList, api.UserInfo{ UID: user.ID, Email: user.Email, UUID: user.UUID, Passwd: user.Passwd, - SpeedLimit: speedlimit, + SpeedLimit: speedLimit, DeviceLimit: deviceLimit, Port: user.Port, Method: user.Method, @@ -747,22 +738,33 @@ func (c *APIClient) ParseUserListResponse(userInfoResponse *[]UserResponse) (*[] return &userList, nil } -// ParseSSPanelNodeInfo parse the response for the given nodeinfor format +// ParseSSPanelNodeInfo parse the response for the given node info format // Only used for SSPanel version >= 2021.11 func (c *APIClient) ParseSSPanelNodeInfo(nodeInfoResponse *NodeInfoResponse) (*api.NodeInfo, error) { - - var speedlimit uint64 = 0 + var speedLimit uint64 = 0 var EnableTLS, EnableVless bool var AlterID uint16 = 0 var TLSType, transportProtocol string + if nodeInfoResponse.Version == "" { + return nil, errors.New("panel version must be 2021.11 or above") + } + + // Check if custom_config is valid + if len(nodeInfoResponse.CustomConfig) == 0 { + return nil, errors.New("custom_config is empty, disable custom config") + } + nodeConfig := new(CustomConfig) - json.Unmarshal(nodeInfoResponse.CustomConfig, nodeConfig) + err := json.Unmarshal(nodeInfoResponse.CustomConfig, nodeConfig) + if err != nil { + return nil, fmt.Errorf("custom_config is error: %v", err) + } if c.SpeedLimit > 0 { - speedlimit = uint64((c.SpeedLimit * 1000000) / 8) + speedLimit = uint64((c.SpeedLimit * 1000000) / 8) } else { - speedlimit = uint64((nodeInfoResponse.SpeedLimit * 1000000) / 8) + speedLimit = uint64((nodeInfoResponse.SpeedLimit * 1000000) / 8) } parsedPort, err := strconv.ParseInt(nodeConfig.OffsetPortNode, 10, 32) @@ -813,11 +815,11 @@ func (c *APIClient) ParseSSPanelNodeInfo(nodeInfoResponse *NodeInfoResponse) (*a } // Create GeneralNodeInfo - nodeinfo := &api.NodeInfo{ + nodeInfo := &api.NodeInfo{ NodeType: c.NodeType, NodeID: c.NodeID, Port: port, - SpeedLimit: speedlimit, + SpeedLimit: speedLimit, AlterID: AlterID, TransportProtocol: transportProtocol, Host: nodeConfig.Host, @@ -830,7 +832,7 @@ func (c *APIClient) ParseSSPanelNodeInfo(nodeInfoResponse *NodeInfoResponse) (*a Header: nodeConfig.Header, } - return nodeinfo, nil + return nodeInfo, nil } func compareVersion(version1, version2 string) int { diff --git a/api/sspanel/sspanel_test.go b/api/sspanel/sspanel_test.go index 433b0d0..44b4a4a 100644 --- a/api/sspanel/sspanel_test.go +++ b/api/sspanel/sspanel_test.go @@ -19,7 +19,7 @@ func CreateClient() api.API { return client } -func TestGetV2rayNodeinfo(t *testing.T) { +func TestGetV2rayNodeInfo(t *testing.T) { client := CreateClient() nodeInfo, err := client.GetNodeInfo() @@ -29,7 +29,7 @@ func TestGetV2rayNodeinfo(t *testing.T) { t.Log(nodeInfo) } -func TestGetSSNodeinfo(t *testing.T) { +func TestGetSSNodeInfo(t *testing.T) { apiConfig := &api.Config{ APIHost: "http://127.0.0.1:667", Key: "123", @@ -44,7 +44,7 @@ func TestGetSSNodeinfo(t *testing.T) { t.Log(nodeInfo) } -func TestGetTrojanNodeinfo(t *testing.T) { +func TestGetTrojanNodeInfo(t *testing.T) { apiConfig := &api.Config{ APIHost: "http://127.0.0.1:667", Key: "123", @@ -59,7 +59,7 @@ func TestGetTrojanNodeinfo(t *testing.T) { t.Log(nodeInfo) } -func TestGetSSinfo(t *testing.T) { +func TestGetSSInfo(t *testing.T) { client := CreateClient() nodeInfo, err := client.GetNodeInfo()