package core

import (
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"net"

	"github.com/lucas-clemente/quic-go"
	"github.com/lunixbochs/struc"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/tobyxdd/hysteria/pkg/acl"
)

// AuthFunc decides whether a client may connect. It receives the client's
// remote address, the auth payload from its hello message, and the
// negotiated server-side send/receive rates. It returns whether the client
// is accepted and a message that is sent back to the client in the server
// hello.
type AuthFunc func(addr net.Addr, auth []byte, sSend uint64, sRecv uint64) (bool, string)

// TCPRequestFunc is a callback handed to every per-client handler.
// NOTE(review): presumably invoked for each TCP request with the ACL
// decision; the call site is outside this file.
type TCPRequestFunc func(addr net.Addr, auth []byte, reqAddr string, action acl.Action, arg string)

// TCPErrorFunc is a callback handed to every per-client handler.
// NOTE(review): presumably invoked when a TCP request fails; the call site
// is outside this file.
type TCPErrorFunc func(addr net.Addr, auth []byte, reqAddr string, err error)

// UDPRequestFunc is a callback handed to every per-client handler.
// NOTE(review): presumably invoked when a UDP session is created; the call
// site is outside this file.
type UDPRequestFunc func(addr net.Addr, auth []byte, sessionID uint32)

// UDPErrorFunc is a callback handed to every per-client handler.
// NOTE(review): presumably invoked when a UDP session fails; the call site
// is outside this file.
type UDPErrorFunc func(addr net.Addr, auth []byte, sessionID uint32, err error)

// Server accepts QUIC sessions, authenticates each client over a control
// stream and then hands the session to a per-client handler.
type Server struct {
	// sendBPS and recvBPS cap the negotiated rates; 0 means no cap.
	sendBPS, recvBPS  uint64
	congestionFactory CongestionFactory
	disableUDP        bool
	aclEngine         *acl.Engine

	authFunc       AuthFunc
	tcpRequestFunc TCPRequestFunc
	tcpErrorFunc   TCPErrorFunc
	udpRequestFunc UDPRequestFunc
	udpErrorFunc   UDPErrorFunc

	// Traffic counters labelled by "auth"; nil when no registry was given.
	upCounterVec, downCounterVec *prometheus.CounterVec

	listener quic.Listener
	// udpConn is the socket opened by NewServer. The listener is created on
	// top of a caller-supplied PacketConn, so the server owns the socket and
	// must close it itself.
	udpConn *net.UDPConn
}

// NewServer binds a UDP socket on addr, starts a QUIC listener on it
// (wrapped with obfuscator when one is given) and returns a Server ready
// for Serve.
//
// When promRegistry is non-nil, uplink/downlink traffic counters are
// created and registered with MustRegister, which panics if registration
// fails.
//
// An error is returned if the address cannot be resolved, the socket cannot
// be opened, or the QUIC listener cannot be created; in the last case the
// socket is closed before returning.
func NewServer(addr string, tlsConfig *tls.Config, quicConfig *quic.Config,
	sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory, disableUDP bool, aclEngine *acl.Engine,
	obfuscator Obfuscator, authFunc AuthFunc,
	tcpRequestFunc TCPRequestFunc, tcpErrorFunc TCPErrorFunc,
	udpRequestFunc UDPRequestFunc, udpErrorFunc UDPErrorFunc,
	promRegistry *prometheus.Registry) (*Server, error) {
	udpAddr, err := net.ResolveUDPAddr("udp", addr)
	if err != nil {
		return nil, err
	}
	udpConn, err := net.ListenUDP("udp", udpAddr)
	if err != nil {
		return nil, err
	}

	var pktConn net.PacketConn = udpConn
	if obfuscator != nil {
		// Wrap PacketConn with obfuscator
		pktConn = &obfsUDPConn{
			Orig:       udpConn,
			Obfuscator: obfuscator,
		}
	}
	listener, err := quic.Listen(pktConn, tlsConfig, quicConfig)
	if err != nil {
		// Don't leak the socket we just opened; the listen error is the one
		// worth reporting, so the close error is deliberately dropped.
		_ = udpConn.Close()
		return nil, err
	}

	s := &Server{
		listener:          listener,
		udpConn:           udpConn,
		sendBPS:           sendBPS,
		recvBPS:           recvBPS,
		congestionFactory: congestionFactory,
		disableUDP:        disableUDP,
		aclEngine:         aclEngine,
		authFunc:          authFunc,
		tcpRequestFunc:    tcpRequestFunc,
		tcpErrorFunc:      tcpErrorFunc,
		udpRequestFunc:    udpRequestFunc,
		udpErrorFunc:      udpErrorFunc,
	}
	if promRegistry != nil {
		s.upCounterVec = prometheus.NewCounterVec(prometheus.CounterOpts{
			Name: "hysteria_traffic_uplink_bytes_total",
		}, []string{"auth"})
		s.downCounterVec = prometheus.NewCounterVec(prometheus.CounterOpts{
			Name: "hysteria_traffic_downlink_bytes_total",
		}, []string{"auth"})
		promRegistry.MustRegister(s.upCounterVec, s.downCounterVec)
	}
	return s, nil
}

// Serve accepts sessions in a loop and handles each one in its own
// goroutine. It only returns when Accept fails, with that error.
func (s *Server) Serve() error {
	for {
		cs, err := s.listener.Accept(context.Background())
		if err != nil {
			return err
		}
		go s.handleClient(cs)
	}
}

// Close shuts down the QUIC listener and releases the underlying UDP
// socket. The listener's close error is returned.
func (s *Server) Close() error {
	err := s.listener.Close()
	if s.udpConn != nil {
		// The listener was built on a PacketConn we supplied, so closing the
		// socket is our job. The listener error takes precedence.
		_ = s.udpConn.Close()
	}
	return err
}

// handleClient runs the full lifecycle of one session: wait for the control
// stream, authenticate, run the per-client handler, then close the session.
func (s *Server) handleClient(cs quic.Session) {
	// Expect the client to create a control stream to send its own information
	ctx, ctxCancel := context.WithTimeout(context.Background(), protocolTimeout)
	stream, err := cs.AcceptStream(ctx)
	ctxCancel()
	if err != nil {
		_ = cs.CloseWithError(closeErrorCodeProtocol, "protocol error")
		return
	}
	// Handle the control stream
	auth, ok, err := s.handleControlStream(cs, stream)
	if err != nil {
		_ = cs.CloseWithError(closeErrorCodeProtocol, "protocol error")
		return
	}
	if !ok {
		_ = cs.CloseWithError(closeErrorCodeAuth, "auth error")
		return
	}
	// Start accepting streams and messages
	sc := newServerClient(cs, auth, s.disableUDP, s.aclEngine,
		s.tcpRequestFunc, s.tcpErrorFunc, s.udpRequestFunc, s.udpErrorFunc,
		s.upCounterVec, s.downCounterVec)
	sc.Run()
	_ = cs.CloseWithError(closeErrorCodeGeneric, "")
}

// handleControlStream performs authentication and speed negotiation on the
// client's control stream.
//
// It reads the protocol version byte and the client hello, derives the
// server-side rates, calls authFunc, and writes the server hello. It returns
// the client's auth payload and whether authFunc accepted it. A non-nil
// error means the exchange itself failed (read/write error, unsupported
// version, or a zero rate from the client).
func (s *Server) handleControlStream(cs quic.Session, stream quic.Stream) ([]byte, bool, error) {
	// Check version
	vb := make([]byte, 1)
	_, err := stream.Read(vb)
	if err != nil {
		return nil, false, err
	}
	if vb[0] != protocolVersion {
		return nil, false, fmt.Errorf("unsupported protocol version %d, expecting %d", vb[0], protocolVersion)
	}
	// Parse client hello
	var ch clientHello
	err = struc.Unpack(stream, &ch)
	if err != nil {
		return nil, false, err
	}
	// Speed
	if ch.Rate.SendBPS == 0 || ch.Rate.RecvBPS == 0 {
		return nil, false, errors.New("invalid rate from client")
	}
	// The client's receive rate is our send rate and vice versa; each is
	// clamped to the server-side limit when one is configured.
	serverSendBPS, serverRecvBPS := ch.Rate.RecvBPS, ch.Rate.SendBPS
	if s.sendBPS > 0 && serverSendBPS > s.sendBPS {
		serverSendBPS = s.sendBPS
	}
	if s.recvBPS > 0 && serverRecvBPS > s.recvBPS {
		serverRecvBPS = s.recvBPS
	}
	// Auth
	ok, msg := s.authFunc(cs.RemoteAddr(), ch.Auth, serverSendBPS, serverRecvBPS)
	// Response
	err = struc.Pack(stream, &serverHello{
		OK: ok,
		Rate: transmissionRate{
			SendBPS: serverSendBPS,
			RecvBPS: serverRecvBPS,
		},
		Message: msg,
	})
	if err != nil {
		return nil, false, err
	}
	// Set the congestion accordingly
	if ok && s.congestionFactory != nil {
		cs.SetCongestionControl(s.congestionFactory(serverSendBPS))
	}
	return ch.Auth, ok, nil
}