mirror of
https://github.com/cmz0228/hysteria-dev.git
synced 2025-08-14 21:31:45 +00:00
refactor: re-org packages
This commit is contained in:
183
core/cs/server.go
Normal file
183
core/cs/server.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package cs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/apernet/hysteria/core/congestion"
|
||||
|
||||
"github.com/apernet/hysteria/core/acl"
|
||||
"github.com/apernet/hysteria/core/pmtud"
|
||||
"github.com/apernet/hysteria/core/transport"
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/lunixbochs/struc"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
)
|
||||
|
||||
type (
|
||||
ConnectFunc func(addr net.Addr, auth []byte, sSend uint64, sRecv uint64) (bool, string)
|
||||
DisconnectFunc func(addr net.Addr, auth []byte, err error)
|
||||
TCPRequestFunc func(addr net.Addr, auth []byte, reqAddr string, action acl.Action, arg string)
|
||||
TCPErrorFunc func(addr net.Addr, auth []byte, reqAddr string, err error)
|
||||
UDPRequestFunc func(addr net.Addr, auth []byte, sessionID uint32)
|
||||
UDPErrorFunc func(addr net.Addr, auth []byte, sessionID uint32, err error)
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
transport *transport.ServerTransport
|
||||
sendBPS, recvBPS uint64
|
||||
disableUDP bool
|
||||
aclEngine *acl.Engine
|
||||
|
||||
connectFunc ConnectFunc
|
||||
disconnectFunc DisconnectFunc
|
||||
tcpRequestFunc TCPRequestFunc
|
||||
tcpErrorFunc TCPErrorFunc
|
||||
udpRequestFunc UDPRequestFunc
|
||||
udpErrorFunc UDPErrorFunc
|
||||
|
||||
upCounterVec, downCounterVec *prometheus.CounterVec
|
||||
connGaugeVec *prometheus.GaugeVec
|
||||
|
||||
pktConn net.PacketConn
|
||||
listener quic.Listener
|
||||
}
|
||||
|
||||
func NewServer(tlsConfig *tls.Config, quicConfig *quic.Config,
|
||||
pktConn net.PacketConn, transport *transport.ServerTransport,
|
||||
sendBPS uint64, recvBPS uint64, disableUDP bool, aclEngine *acl.Engine,
|
||||
connectFunc ConnectFunc, disconnectFunc DisconnectFunc,
|
||||
tcpRequestFunc TCPRequestFunc, tcpErrorFunc TCPErrorFunc,
|
||||
udpRequestFunc UDPRequestFunc, udpErrorFunc UDPErrorFunc, promRegistry *prometheus.Registry,
|
||||
) (*Server, error) {
|
||||
quicConfig.DisablePathMTUDiscovery = quicConfig.DisablePathMTUDiscovery || pmtud.DisablePathMTUDiscovery
|
||||
listener, err := quic.Listen(pktConn, tlsConfig, quicConfig)
|
||||
if err != nil {
|
||||
_ = pktConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
s := &Server{
|
||||
pktConn: pktConn,
|
||||
listener: listener,
|
||||
transport: transport,
|
||||
sendBPS: sendBPS,
|
||||
recvBPS: recvBPS,
|
||||
disableUDP: disableUDP,
|
||||
aclEngine: aclEngine,
|
||||
connectFunc: connectFunc,
|
||||
disconnectFunc: disconnectFunc,
|
||||
tcpRequestFunc: tcpRequestFunc,
|
||||
tcpErrorFunc: tcpErrorFunc,
|
||||
udpRequestFunc: udpRequestFunc,
|
||||
udpErrorFunc: udpErrorFunc,
|
||||
}
|
||||
if promRegistry != nil {
|
||||
s.upCounterVec = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "hysteria_traffic_uplink_bytes_total",
|
||||
}, []string{"auth"})
|
||||
s.downCounterVec = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "hysteria_traffic_downlink_bytes_total",
|
||||
}, []string{"auth"})
|
||||
s.connGaugeVec = prometheus.NewGaugeVec(prometheus.GaugeOpts{
|
||||
Name: "hysteria_active_conn",
|
||||
}, []string{"auth"})
|
||||
promRegistry.MustRegister(s.upCounterVec, s.downCounterVec, s.connGaugeVec)
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *Server) Serve() error {
|
||||
for {
|
||||
cc, err := s.listener.Accept(context.Background())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go s.handleClient(cc)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) Close() error {
|
||||
err := s.listener.Close()
|
||||
_ = s.pktConn.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Server) handleClient(cc quic.Connection) {
|
||||
// Expect the client to create a control stream to send its own information
|
||||
ctx, ctxCancel := context.WithTimeout(context.Background(), protocolTimeout)
|
||||
stream, err := cc.AcceptStream(ctx)
|
||||
ctxCancel()
|
||||
if err != nil {
|
||||
_ = qErrorProtocol.Send(cc)
|
||||
return
|
||||
}
|
||||
// Handle the control stream
|
||||
auth, ok, err := s.handleControlStream(cc, stream)
|
||||
if err != nil {
|
||||
_ = qErrorProtocol.Send(cc)
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
_ = qErrorAuth.Send(cc)
|
||||
return
|
||||
}
|
||||
// Start accepting streams and messages
|
||||
sc := newServerClient(cc, s.transport, auth, s.disableUDP, s.aclEngine,
|
||||
s.tcpRequestFunc, s.tcpErrorFunc, s.udpRequestFunc, s.udpErrorFunc,
|
||||
s.upCounterVec, s.downCounterVec, s.connGaugeVec)
|
||||
err = sc.Run()
|
||||
_ = qErrorGeneric.Send(cc)
|
||||
s.disconnectFunc(cc.RemoteAddr(), auth, err)
|
||||
}
|
||||
|
||||
// Auth & negotiate speed
|
||||
func (s *Server) handleControlStream(cc quic.Connection, stream quic.Stream) ([]byte, bool, error) {
|
||||
// Check version
|
||||
vb := make([]byte, 1)
|
||||
_, err := stream.Read(vb)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if vb[0] != protocolVersion {
|
||||
return nil, false, fmt.Errorf("unsupported protocol version %d, expecting %d", vb[0], protocolVersion)
|
||||
}
|
||||
// Parse client hello
|
||||
var ch clientHello
|
||||
err = struc.Unpack(stream, &ch)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
// Speed
|
||||
if ch.Rate.SendBPS == 0 || ch.Rate.RecvBPS == 0 {
|
||||
return nil, false, errors.New("invalid rate from client")
|
||||
}
|
||||
serverSendBPS, serverRecvBPS := ch.Rate.RecvBPS, ch.Rate.SendBPS
|
||||
if s.sendBPS > 0 && serverSendBPS > s.sendBPS {
|
||||
serverSendBPS = s.sendBPS
|
||||
}
|
||||
if s.recvBPS > 0 && serverRecvBPS > s.recvBPS {
|
||||
serverRecvBPS = s.recvBPS
|
||||
}
|
||||
// Auth
|
||||
ok, msg := s.connectFunc(cc.RemoteAddr(), ch.Auth, serverSendBPS, serverRecvBPS)
|
||||
// Response
|
||||
err = struc.Pack(stream, &serverHello{
|
||||
OK: ok,
|
||||
Rate: maxRate{
|
||||
SendBPS: serverSendBPS,
|
||||
RecvBPS: serverRecvBPS,
|
||||
},
|
||||
Message: msg,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
// Set the congestion accordingly
|
||||
if ok {
|
||||
cc.SetCongestionControl(congestion.NewBrutalSender(serverSendBPS))
|
||||
}
|
||||
return ch.Auth, ok, nil
|
||||
}
|
Reference in New Issue
Block a user