package forwarder

import (
	"crypto/rand"
	"crypto/rsa"
	"crypto/tls"
	"crypto/x509"
	"encoding/pem"
	"math/big"
	"net"

	"github.com/tobyxdd/hysteria/internal/forwarder"
)

// server keeps the shared configuration, the user-supplied callbacks and
// the QUIC servers it has created, keyed by their listen address.
type server struct {
	config    ServerConfig
	callbacks ServerCallbacks
	entries   map[string]*forwarder.QUICServer
}

// NewServer returns a Server built from config and callbacks.
//
// Unset configuration values are replaced with defaults before use:
//   - a nil TLSConfig becomes the result of generateInsecureTLSConfig;
//   - a nil MaxSpeedPerClient becomes a zero-valued Speed;
//   - a zero MaxReceiveWindowPerConnection / MaxReceiveWindowPerClient
//     becomes defaultReceiveWindowConn / defaultReceiveWindow;
//   - a non-positive MaxConnectionPerClient becomes defaultMaxClientConn.
func NewServer(config ServerConfig, callbacks ServerCallbacks) Server {
	// Normalize the configuration before storing it.
	if config.TLSConfig == nil {
		config.TLSConfig = generateInsecureTLSConfig()
	}
	if config.MaxSpeedPerClient == nil {
		config.MaxSpeedPerClient = &Speed{}
	}
	if config.MaxReceiveWindowPerConnection == 0 {
		config.MaxReceiveWindowPerConnection = defaultReceiveWindowConn
	}
	if config.MaxReceiveWindowPerClient == 0 {
		config.MaxReceiveWindowPerClient = defaultReceiveWindow
	}
	if config.MaxConnectionPerClient <= 0 {
		config.MaxConnectionPerClient = defaultMaxClientConn
	}

	srv := &server{
		config:    config,
		callbacks: callbacks,
		entries:   make(map[string]*forwarder.QUICServer),
	}
	return srv
}

// Add creates a forwarder.QUICServer for the listenAddr/remoteAddr pair
// using the server-wide configuration and stores it under listenAddr.
//
// Every callback handed to the QUIC server is a thin wrapper that invokes
// the matching user callback (when one is set) with listenAddr prepended.
// The error from forwarder.NewQUICServer is returned unchanged, in which
// case nothing is stored.
//
// NOTE(review): an entry already stored under listenAddr is overwritten in
// the map without being closed here.
func (s *server) Add(listenAddr, remoteAddr string) error {
	onConnected := func(addr net.Addr, name string, sSend uint64, sRecv uint64) {
		if cb := s.callbacks.ClientConnectedCallback; cb != nil {
			cb(listenAddr, addr, name, sSend, sRecv)
		}
	}
	onDisconnected := func(addr net.Addr, name string, err error) {
		if cb := s.callbacks.ClientDisconnectedCallback; cb != nil {
			cb(listenAddr, addr, name, err)
		}
	}
	onNewStream := func(addr net.Addr, name string, id int) {
		if cb := s.callbacks.ClientNewStreamCallback; cb != nil {
			cb(listenAddr, addr, name, id)
		}
	}
	onStreamClosed := func(addr net.Addr, name string, id int, err error) {
		if cb := s.callbacks.ClientStreamClosedCallback; cb != nil {
			cb(listenAddr, addr, name, id, err)
		}
	}
	// The address reported here is the one supplied by the QUIC server,
	// not the remoteAddr argument of Add.
	onTCPError := func(raddr string, err error) {
		if cb := s.callbacks.TCPErrorCallback; cb != nil {
			cb(listenAddr, raddr, err)
		}
	}

	cfg := s.config
	qs, err := forwarder.NewQUICServer(
		listenAddr,
		remoteAddr,
		cfg.BannerMessage,
		cfg.TLSConfig,
		cfg.MaxSpeedPerClient.SendBPS,
		cfg.MaxSpeedPerClient.ReceiveBPS,
		cfg.MaxReceiveWindowPerConnection,
		cfg.MaxReceiveWindowPerClient,
		cfg.MaxConnectionPerClient,
		forwarder.CongestionFactory(cfg.CongestionFactory),
		onConnected,
		onDisconnected,
		onNewStream,
		onStreamClosed,
		onTCPError,
	)
	if err != nil {
		return err
	}

	s.entries[listenAddr] = qs
	return nil
}

// Remove drops the entry stored under listenAddr and, when a non-nil QUIC
// server was stored there, returns the result of closing it. Removing an
// address that has no entry returns nil.
func (s *server) Remove(listenAddr string) error {
	// Deferred so the map entry is deleted only after Close has run.
	defer delete(s.entries, listenAddr)

	qs := s.entries[listenAddr]
	if qs == nil {
		return nil
	}
	return qs.Close()
}

// Stats returns one Stats value per stored entry, keyed by listen address,
// populated from the three values reported by that entry's Stats method.
func (s *server) Stats() map[string]Stats {
	result := make(map[string]Stats, len(s.entries))
	for listenAddr, qs := range s.entries {
		remote, inbound, outbound := qs.Stats()
		result[listenAddr] = Stats{
			RemoteAddr:    remote,
			inboundBytes:  inbound,
			outboundBytes: outbound,
		}
	}
	return result
}

// generateInsecureTLSConfig builds the fallback TLS configuration used when
// the caller supplies none.
//
// It generates a fresh 1024-bit RSA key and a self-signed certificate whose
// template only sets the serial number, then returns a tls.Config carrying
// that certificate and advertising TLSAppProtocol. Any failure along the way
// panics.
func generateInsecureTLSConfig() *tls.Config {
	privKey, err := rsa.GenerateKey(rand.Reader, 1024)
	if err != nil {
		panic(err)
	}

	// The same template serves as both certificate and parent, which makes
	// the certificate self-signed.
	tmpl := x509.Certificate{SerialNumber: big.NewInt(1)}
	der, err := x509.CreateCertificate(rand.Reader, &tmpl, &tmpl, &privKey.PublicKey, privKey)
	if err != nil {
		panic(err)
	}

	keyBlock := pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(privKey),
	}
	certBlock := pem.Block{
		Type:  "CERTIFICATE",
		Bytes: der,
	}

	pair, err := tls.X509KeyPair(pem.EncodeToMemory(&certBlock), pem.EncodeToMemory(&keyBlock))
	if err != nil {
		panic(err)
	}

	return &tls.Config{
		Certificates: []tls.Certificate{pair},
		NextProtos:   []string{TLSAppProtocol},
	}
}