diff --git a/app/cmd/server.go b/app/cmd/server.go index 386c4f2..435d2c0 100644 --- a/app/cmd/server.go +++ b/app/cmd/server.go @@ -20,6 +20,7 @@ import ( "github.com/apernet/hysteria/extras/auth" "github.com/apernet/hysteria/extras/obfs" "github.com/apernet/hysteria/extras/outbounds" + "github.com/apernet/hysteria/extras/trafficlogger" ) var serverCmd = &cobra.Command{ @@ -46,6 +47,7 @@ type serverConfig struct { Resolver serverConfigResolver `mapstructure:"resolver"` ACL serverConfigACL `mapstructure:"acl"` Outbounds []serverConfigOutboundEntry `mapstructure:"outbounds"` + TrafficStats serverConfigTrafficStats `mapstructure:"trafficStats"` Masquerade serverConfigMasquerade `mapstructure:"masquerade"` } @@ -160,6 +162,10 @@ type serverConfigOutboundEntry struct { SOCKS5 serverConfigOutboundSOCKS5 `mapstructure:"socks5"` } +type serverConfigTrafficStats struct { + Listen string `mapstructure:"listen"` +} + type serverConfigMasqueradeFile struct { Dir string `mapstructure:"dir"` } @@ -504,6 +510,15 @@ func (c *serverConfig) fillEventLogger(hyConfig *server.Config) error { return nil } +func (c *serverConfig) fillTrafficLogger(hyConfig *server.Config) error { + if c.TrafficStats.Listen != "" { + tss := trafficlogger.NewTrafficStatsServer() + hyConfig.TrafficLogger = tss + go runTrafficStatsServer(c.TrafficStats.Listen, tss) + } + return nil +} + func (c *serverConfig) fillMasqHandler(hyConfig *server.Config) error { switch strings.ToLower(c.Masquerade.Type) { case "", "404": @@ -557,6 +572,7 @@ func (c *serverConfig) Config() (*server.Config, error) { c.fillUDPIdleTimeout, c.fillAuthenticator, c.fillEventLogger, + c.fillTrafficLogger, c.fillMasqHandler, } for _, f := range fillers { @@ -594,6 +610,13 @@ func runServer(cmd *cobra.Command, args []string) { } } +func runTrafficStatsServer(listen string, handler http.Handler) { + logger.Info("traffic stats server up and running", zap.String("listen", listen)) + if err := http.ListenAndServe(listen, handler); err != nil { + logger.Fatal("failed to serve traffic stats", zap.Error(err)) + } +} + func geoipDownloadFunc(filename, url string) { logger.Info("downloading GeoIP database", zap.String("filename", filename), zap.String("url", url)) } diff --git a/app/cmd/server_test.go b/app/cmd/server_test.go index f215521..45d7e33 100644 --- a/app/cmd/server_test.go +++ b/app/cmd/server_test.go @@ -124,6 +124,9 @@ func TestServerConfig(t *testing.T) { }, }, }, + TrafficStats: serverConfigTrafficStats{ + Listen: ":9999", + }, Masquerade: serverConfigMasquerade{ Type: "proxy", File: serverConfigMasqueradeFile{ diff --git a/app/cmd/server_test.yaml b/app/cmd/server_test.yaml index be3c083..432fb9f 100644 --- a/app/cmd/server_test.yaml +++ b/app/cmd/server_test.yaml @@ -92,6 +92,9 @@ outbounds: username: hackerman password: Elliot Alderson +trafficStats: + listen: :9999 + masquerade: type: proxy file: diff --git a/extras/trafficlogger/http.go b/extras/trafficlogger/http.go new file mode 100644 index 0000000..bd22a0d --- /dev/null +++ b/extras/trafficlogger/http.go @@ -0,0 +1,114 @@ +package trafficlogger + +import ( + "encoding/json" + "net/http" + "strconv" + "sync" + + "github.com/apernet/hysteria/core/server" +) + +const ( + indexHTML = ` Hysteria Traffic Stats API Server

This is a Hysteria Traffic Stats API server.

Check the documentation for usage.

` +) + +// TrafficStatsServer implements both server.TrafficLogger and http.Handler +// to provide a simple HTTP API to get the traffic stats per user. +type TrafficStatsServer interface { + server.TrafficLogger + http.Handler +} + +func NewTrafficStatsServer() TrafficStatsServer { + return &trafficStatsServerImpl{ + StatsMap: make(map[string]*trafficStatsEntry), + KickMap: make(map[string]struct{}), + } +} + +type trafficStatsServerImpl struct { + Mutex sync.RWMutex + StatsMap map[string]*trafficStatsEntry + KickMap map[string]struct{} +} + +type trafficStatsEntry struct { + Tx uint64 `json:"tx"` + Rx uint64 `json:"rx"` +} + +func (s *trafficStatsServerImpl) Log(id string, tx, rx uint64) (ok bool) { + s.Mutex.Lock() + defer s.Mutex.Unlock() + + _, ok = s.KickMap[id] + if ok { + delete(s.KickMap, id) + return false + } + + entry, ok := s.StatsMap[id] + if !ok { + entry = &trafficStatsEntry{} + s.StatsMap[id] = entry + } + entry.Tx += tx + entry.Rx += rx + + return true +} + +func (s *trafficStatsServerImpl) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet && r.URL.Path == "/" { + _, _ = w.Write([]byte(indexHTML)) + return + } + if r.Method == http.MethodGet && r.URL.Path == "/traffic" { + s.getTraffic(w, r) + return + } + if r.Method == http.MethodPost && r.URL.Path == "/kick" { + s.kick(w, r) + return + } + http.NotFound(w, r) +} + +func (s *trafficStatsServerImpl) getTraffic(w http.ResponseWriter, r *http.Request) { + bClear, _ := strconv.ParseBool(r.URL.Query().Get("clear")) + var jb []byte + var err error + if bClear { + s.Mutex.Lock() + jb, err = json.Marshal(s.StatsMap) + s.StatsMap = make(map[string]*trafficStatsEntry) + s.Mutex.Unlock() + } else { + s.Mutex.RLock() + jb, err = json.Marshal(s.StatsMap) + s.Mutex.RUnlock() + } + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json; charset=utf-8") + _, _ = w.Write(jb) +} + +func (s *trafficStatsServerImpl) kick(w http.ResponseWriter, r *http.Request) { + var ids []string + err := json.NewDecoder(r.Body).Decode(&ids) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + s.Mutex.Lock() + for _, id := range ids { + s.KickMap[id] = struct{}{} + } + s.Mutex.Unlock() + + w.WriteHeader(http.StatusOK) +}