XrayR/app/mydispatcher/default.go
Senis John f35a056cc6
Update several dependencies for bug fixes and improvements
This commit includes updates to several dependencies in go.sum file in order to incorporate bug fixes and performance improvements. The updated dependencies include 'cloud.google.com/go/compute', 'github.com/gaukas/godicttls', 'github.com/google/pprof' and 'github.com/xtls/xray-core' among others.

These updates are crucial to ensure the project runs efficiently and securely, with all known issues in these dependencies addressed in their latest versions.
2023-09-04 08:53:08 +08:00

469 lines
13 KiB
Go

package mydispatcher
//go:generate go run github.com/xtls/xray-core/common/errors/errorgen
import (
"context"
"fmt"
"strings"
"sync"
"time"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/buf"
"github.com/xtls/xray-core/common/log"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/protocol"
"github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/features/dns"
"github.com/xtls/xray-core/features/outbound"
"github.com/xtls/xray-core/features/policy"
"github.com/xtls/xray-core/features/routing"
routingSession "github.com/xtls/xray-core/features/routing/session"
"github.com/xtls/xray-core/features/stats"
"github.com/xtls/xray-core/transport"
"github.com/xtls/xray-core/transport/pipe"
"github.com/XrayR-project/XrayR/common/limiter"
"github.com/XrayR-project/XrayR/common/rule"
)
// errSniffingTimeout is the error sniffer reports when its read attempts are
// used up before the payload could be classified.
var errSniffingTimeout = newError("timeout on sniffing")
// cachedReader wraps a pipe.Reader and remembers the data pulled from it
// while sniffing, so that the same data can later be returned by
// ReadMultiBuffer / ReadMultiBufferTimeout before any fresh data is read.
// The embedded mutex guards cache.
type cachedReader struct {
sync.Mutex
// reader is the underlying pipe the data comes from.
reader *pipe.Reader
// cache holds buffers already read from reader but not yet handed out.
cache buf.MultiBuffer
}
// Cache waits up to 100ms for data from the underlying pipe, appends whatever
// arrived to the internal cache, and then overwrites b with a copy of the
// leading bytes of the cache (at most buf.Size bytes are requested).
//
// The error of the timed read and of the merge are discarded; only the data
// that actually arrived is used.
func (r *cachedReader) Cache(b *buf.Buffer) {
	// The blocking read happens before the lock is taken.
	incoming, _ := r.reader.ReadMultiBufferTimeout(100 * time.Millisecond)

	r.Lock()
	if !incoming.IsEmpty() {
		r.cache, _ = buf.MergeMulti(r.cache, incoming)
	}

	// Refill b from scratch with the current head of the cache.
	b.Clear()
	dst := b.Extend(buf.Size)
	copied := r.cache.Copy(dst)
	b.Resize(0, int32(copied))
	r.Unlock()
}
// readInternal hands over the cached buffers and clears the cache.
// It returns nil when the cache is nil or empty.
func (r *cachedReader) readInternal() buf.MultiBuffer {
	r.Lock()
	defer r.Unlock()

	if r.cache == nil || r.cache.IsEmpty() {
		return nil
	}

	cached := r.cache
	r.cache = nil
	return cached
}
// ReadMultiBuffer returns the cached data first, if there is any; otherwise
// it reads from the underlying pipe.
func (r *cachedReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
	if cached := r.readInternal(); cached != nil {
		return cached, nil
	}
	return r.reader.ReadMultiBuffer()
}
// ReadMultiBufferTimeout returns the cached data first, if there is any;
// otherwise it reads from the underlying pipe, passing timeout through.
func (r *cachedReader) ReadMultiBufferTimeout(timeout time.Duration) (buf.MultiBuffer, error) {
	if cached := r.readInternal(); cached != nil {
		return cached, nil
	}
	return r.reader.ReadMultiBufferTimeout(timeout)
}
// Interrupt releases any cached buffers and then interrupts the underlying
// pipe reader.
func (r *cachedReader) Interrupt() {
	// Drop the cache under the lock; the lock is released before the
	// underlying reader is interrupted.
	func() {
		r.Lock()
		defer r.Unlock()
		if r.cache != nil {
			r.cache = buf.ReleaseMulti(r.cache)
		}
	}()

	r.reader.Interrupt()
}
// DefaultDispatcher is a default implementation of Dispatcher.
// Besides routing connections to outbound handlers it applies per-user
// limits through Limiter and destination rules through RuleManager.
type DefaultDispatcher struct {
// ohm resolves outbound handlers by tag and supplies the default handler.
ohm outbound.Manager
// router picks the route for a connection; it is nil-checked before use.
router routing.Router
// policy supplies the per-level policy that controls user traffic stats.
policy policy.Manager
// stats is where the user uplink/downlink counters are registered.
stats stats.Manager
// dns is consulted for static host mappings when it implements dns.HostsLookup.
dns dns.Client
// fdns is the optional fake DNS engine; it is only set when the feature is available.
fdns dns.FakeDNSEngine
// Limiter provides the per-user bucket and device-limit decision used in getLink.
Limiter *limiter.Limiter
// RuleManager decides in routedDispatch whether a user's destination is rejected.
RuleManager *rule.Manager
}
func init() {
common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
d := new(DefaultDispatcher)
if err := core.RequireFeatures(ctx, func(om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager, dc dns.Client) error {
core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) {
d.fdns = fdns
})
return d.Init(config.(*Config), om, router, pm, sm, dc)
}); err != nil {
return nil, err
}
return d, nil
}))
}
// Init initializes DefaultDispatcher.
// It stores the supplied features and creates a fresh Limiter and
// RuleManager. config is currently not read. Init always returns nil.
func (d *DefaultDispatcher) Init(config *Config, om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager, dns dns.Client) error {
	// Features handed in by the core.
	d.ohm, d.router, d.policy, d.stats = om, router, pm, sm
	d.dns = dns

	// Components owned by this dispatcher.
	d.Limiter = limiter.New()
	d.RuleManager = rule.New()

	return nil
}
// Type implements common.HasType.
// It returns routing.DispatcherType().
func (*DefaultDispatcher) Type() interface{} {
return routing.DispatcherType()
}
// Start implements common.Runnable.
// It does nothing and always returns nil.
func (*DefaultDispatcher) Start() error {
return nil
}
// Close implements common.Closable.
// It does nothing and always returns nil.
func (*DefaultDispatcher) Close() error {
return nil
}
// getLink creates the two pipes that connect an inbound connection to its
// outbound handler and returns them as (inboundLink, outboundLink).
//
// inboundLink reads from the downlink pipe and writes to the uplink pipe;
// outboundLink is the mirror image. When the context carries an inbound
// session with a user that has a non-empty email, the writers are further
// wrapped:
//   - the Limiter may reject the connection (device limit), in which case
//     both pipes are shut down and an error is returned;
//   - the Limiter may request rate limiting, which wraps both writers;
//   - depending on the user's policy, uplink/downlink traffic counters are
//     attached.
//
// network and sniffing are not used by this implementation.
func (d *DefaultDispatcher) getLink(ctx context.Context, network net.Network, sniffing session.SniffingRequest) (*transport.Link, *transport.Link, error) {
	pipeOpts := pipe.OptionsFromContext(ctx)
	upReader, upWriter := pipe.New(pipeOpts...)
	downReader, downWriter := pipe.New(pipeOpts...)

	inboundLink := &transport.Link{Reader: downReader, Writer: upWriter}
	outboundLink := &transport.Link{Reader: upReader, Writer: downWriter}

	// Without an identifiable user there is nothing to limit or to count.
	inboundSession := session.InboundFromContext(ctx)
	if inboundSession == nil {
		return inboundLink, outboundLink, nil
	}
	user := inboundSession.User
	if user == nil || len(user.Email) == 0 {
		return inboundLink, outboundLink, nil
	}

	// Speed Limit and Device Limit
	sourceIP := inboundSession.Source.Address.IP().String()
	bucket, rateLimited, reject := d.Limiter.GetUserBucket(inboundSession.Tag, user.Email, sourceIP)
	if reject {
		newError("Devices reach the limit: ", user.Email).AtWarning().WriteToLog()
		common.Close(outboundLink.Writer)
		common.Close(inboundLink.Writer)
		common.Interrupt(outboundLink.Reader)
		common.Interrupt(inboundLink.Reader)
		return nil, nil, newError("Devices reach the limit: ", user.Email)
	}
	if rateLimited {
		inboundLink.Writer = d.Limiter.RateWriter(inboundLink.Writer, bucket)
		outboundLink.Writer = d.Limiter.RateWriter(outboundLink.Writer, bucket)
	}

	// Per-user traffic statistics, enabled by the policy of the user's level.
	userPolicy := d.policy.ForLevel(user.Level)
	if userPolicy.Stats.UserUplink {
		counterName := "user>>>" + user.Email + ">>>traffic>>>uplink"
		if counter, _ := stats.GetOrRegisterCounter(d.stats, counterName); counter != nil {
			inboundLink.Writer = &SizeStatWriter{Counter: counter, Writer: inboundLink.Writer}
		}
	}
	if userPolicy.Stats.UserDownlink {
		counterName := "user>>>" + user.Email + ">>>traffic>>>downlink"
		if counter, _ := stats.GetOrRegisterCounter(d.stats, counterName); counter != nil {
			outboundLink.Writer = &SizeStatWriter{Counter: counter, Writer: outboundLink.Writer}
		}
	}

	return inboundLink, outboundLink, nil
}
// shouldOverride reports whether the sniffed result should replace the
// destination of the connection.
//
// It returns false when the lower-cased sniffed domain equals an entry of
// request.ExcludeForDomain. Otherwise it returns true as soon as one entry p
// of request.OverrideDestinationForProtocol satisfies any of:
//   - the sniffed protocol starts with p;
//   - p is "fakedns", the sniffed protocol is not "bittorrent", the
//     destination is an IP inside the fake DNS pool (requires a
//     dns.FakeDNSEngineRev0 engine);
//   - the result implements SnifferIsProtoSubsetOf and reports being a
//     subset of p.
func (d *DefaultDispatcher) shouldOverride(ctx context.Context, result SniffResult, request session.SniffingRequest, destination net.Destination) bool {
	lowerDomain := strings.ToLower(result.Domain())
	for _, excluded := range request.ExcludeForDomain {
		if lowerDomain == excluded {
			return false
		}
	}

	sniffedProto := result.Protocol()
	if composite, ok := result.(SnifferResultComposite); ok {
		sniffedProto = composite.ProtocolForDomainResult()
	}

	// Neither assertion depends on the loop variable, so do them once.
	fakeDNS, hasFakeDNSRev0 := d.fdns.(dns.FakeDNSEngineRev0)
	subsetResult, hasSubset := result.(SnifferIsProtoSubsetOf)

	for _, wanted := range request.OverrideDestinationForProtocol {
		if strings.HasPrefix(sniffedProto, wanted) {
			return true
		}
		if hasFakeDNSRev0 && sniffedProto != "bittorrent" && wanted == "fakedns" &&
			destination.Address.Family().IsIP() && fakeDNS.IsIPInIPPool(destination.Address) {
			newError("Using sniffer ", sniffedProto, " since the fake DNS missed").WriteToLog(session.ExportIDToError(ctx))
			return true
		}
		if hasSubset && subsetResult.IsProtoSubsetOf(wanted) {
			return true
		}
	}

	return false
}
// Dispatch implements routing.Dispatcher.
//
// It creates a fresh pair of links for destination, starts dispatching the
// outbound side in a new goroutine and returns the inbound side to the
// caller. When sniffing is enabled in the session content, the goroutine
// first sniffs the traffic and may replace the destination address with the
// sniffed domain before routing.
//
// An invalid destination is reported as an error, in the same way as
// DispatchLink does, instead of panicking. An error is also returned when
// getLink rejects the connection.
func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destination) (*transport.Link, error) {
	if !destination.IsValid() {
		return nil, newError("Dispatcher: Invalid destination.")
	}

	ob := &session.Outbound{
		Target: destination,
	}
	ctx = session.ContextWithOutbound(ctx, ob)

	content := session.ContentFromContext(ctx)
	if content == nil {
		content = new(session.Content)
		ctx = session.ContextWithContent(ctx, content)
	}

	sniffingRequest := content.SniffingRequest
	inbound, outbound, err := d.getLink(ctx, destination.Network, sniffingRequest)
	if err != nil {
		return nil, err
	}

	if !sniffingRequest.Enabled {
		go d.routedDispatch(ctx, outbound, destination)
		return inbound, nil
	}

	go func() {
		// getLink builds the outbound reader with pipe.New and only wraps
		// the writers, so the reader is expected to be a *pipe.Reader here.
		cReader := &cachedReader{
			reader: outbound.Reader.(*pipe.Reader),
		}
		outbound.Reader = cReader

		result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network)
		if err == nil {
			content.Protocol = result.Protocol()
		}
		if err == nil && d.shouldOverride(ctx, result, sniffingRequest, destination) {
			domain := result.Domain()
			newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
			destination.Address = net.ParseAddress(domain)
			// With RouteOnly the sniffed domain only goes into RouteTarget
			// and Target is left as is, except for fakedns results.
			if sniffingRequest.RouteOnly && result.Protocol() != "fakedns" {
				ob.RouteTarget = destination
			} else {
				ob.Target = destination
			}
		}
		d.routedDispatch(ctx, outbound, destination)
	}()

	return inbound, nil
}
// DispatchLink implements routing.Dispatcher.
//
// It dispatches the caller-supplied outbound link to destination in a new
// goroutine and returns immediately. An invalid destination is reported as an
// error. When sniffing is enabled in the session content, the goroutine
// first sniffs the traffic and may replace the destination address with the
// sniffed domain before routing.
//
// Sniffing needs the link's reader to be a *pipe.Reader. Because the link
// comes from the caller, this is checked; for any other reader type the
// connection is dispatched without sniffing and a warning is logged, rather
// than panicking inside the goroutine.
func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.Destination, outbound *transport.Link) error {
	if !destination.IsValid() {
		return newError("Dispatcher: Invalid destination.")
	}

	ob := &session.Outbound{
		Target: destination,
	}
	ctx = session.ContextWithOutbound(ctx, ob)

	content := session.ContentFromContext(ctx)
	if content == nil {
		content = new(session.Content)
		ctx = session.ContextWithContent(ctx, content)
	}

	sniffingRequest := content.SniffingRequest
	if !sniffingRequest.Enabled {
		go d.routedDispatch(ctx, outbound, destination)
		return nil
	}

	go func() {
		pipeReader, ok := outbound.Reader.(*pipe.Reader)
		if !ok {
			// cachedReader can only wrap a pipe reader; skip sniffing.
			newError("sniffing skipped: link reader is not a pipe reader").AtWarning().WriteToLog(session.ExportIDToError(ctx))
			d.routedDispatch(ctx, outbound, destination)
			return
		}

		cReader := &cachedReader{
			reader: pipeReader,
		}
		outbound.Reader = cReader

		result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network)
		if err == nil {
			content.Protocol = result.Protocol()
		}
		if err == nil && d.shouldOverride(ctx, result, sniffingRequest, destination) {
			domain := result.Domain()
			newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
			destination.Address = net.ParseAddress(domain)
			// With RouteOnly the sniffed domain only goes into RouteTarget
			// and Target is left as is, except for fakedns results.
			if sniffingRequest.RouteOnly && result.Protocol() != "fakedns" {
				ob.RouteTarget = destination
			} else {
				ob.Target = destination
			}
		}
		d.routedDispatch(ctx, outbound, destination)
	}()

	return nil
}
// sniffer tries to identify the traffic of a connection.
//
// It always performs a metadata sniff first and returns that result directly
// when metadataOnly is set. Otherwise it makes up to two attempts to read
// payload through cReader and to sniff its content, stopping early when the
// context is done, when the sniffer gives a verdict other than
// common.ErrNoClue, or when the payload buffer is full.
//
// The outcomes are combined as follows:
//   - metadata sniff failed: the content result and content error are returned;
//   - metadata sniff succeeded, content sniff failed: the metadata result;
//   - both succeeded: CompositeResult of the two.
func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool, network net.Network) (SniffResult, error) {
	payload := buf.New()
	defer payload.Release()

	s := NewSniffer(ctx)

	metaResult, metaErr := s.SniffMetadata(ctx)
	if metadataOnly {
		return metaResult, metaErr
	}

	const maxAttempts = 2
	sniffContent := func() (SniffResult, error) {
		for attempt := 1; ; attempt++ {
			// Non-blocking cancellation check before every attempt.
			select {
			case <-ctx.Done():
				return nil, ctx.Err()
			default:
			}

			if attempt > maxAttempts {
				return nil, errSniffingTimeout
			}

			cReader.Cache(payload)
			if !payload.IsEmpty() {
				result, err := s.Sniff(ctx, payload.Bytes(), network)
				if err != common.ErrNoClue {
					return result, err
				}
			}
			if payload.IsFull() {
				return nil, errUnknownContent
			}
		}
	}
	contentResult, contentErr := sniffContent()

	switch {
	case metaErr != nil:
		return contentResult, contentErr
	case contentErr != nil:
		return metaResult, nil
	default:
		return CompositeResult(metaResult, contentResult), nil
	}
}
// routedDispatch selects the outbound handler for link and hands the link
// over to it. On every failure path the link is shut down (writer closed,
// reader interrupted) and the function returns without dispatching.
//
// Steps, in order:
//  1. If the DNS client offers static host mappings and the destination is a
//     domain, a mapped address replaces the destination address.
//  2. If the inbound session carries a user, the RuleManager is asked whether
//     this user may reach the destination; a hit rejects the connection.
//  3. The handler is chosen: a forced outbound tag from the context wins,
//     then the router's pick, then the handler whose tag equals the inbound
//     tag, and finally the default handler.
//  4. The detour is recorded in the access log message, if one is present.
func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.Link, destination net.Destination) {
	ob := session.OutboundFromContext(ctx)
	if hosts, ok := d.dns.(dns.HostsLookup); ok && destination.Address.Family().IsDomain() {
		proxied := hosts.LookupHosts(ob.Target.String())
		if proxied != nil {
			// Remember whether the destination was the route-only target
			// before its address is replaced.
			ro := ob.RouteTarget == destination
			destination.Address = *proxied
			if ro {
				ob.RouteTarget = destination
			} else {
				ob.Target = destination
			}
		}
	}

	var handler outbound.Handler

	// Check if domain and protocol hit the rule.
	// The inbound session may be absent from the context, so it is
	// nil-checked just like in getLink.
	sessionInbound := session.InboundFromContext(ctx)
	if sessionInbound != nil && sessionInbound.User != nil {
		if d.RuleManager.Detect(sessionInbound.Tag, destination.String(), sessionInbound.User.Email) {
			newError(fmt.Sprintf("User %s access %s reject by rule", sessionInbound.User.Email, destination.String())).AtError().WriteToLog()
			common.Close(link.Writer)
			common.Interrupt(link.Reader)
			return
		}
	}

	routingLink := routingSession.AsRoutingContext(ctx)
	inTag := routingLink.GetInboundTag()

	// isPickRoute records how the handler was chosen:
	// 0 = fallback, 1 = forced outbound tag, 2 = router decision.
	isPickRoute := 0
	if forcedOutboundTag := session.GetForcedOutboundTagFromContext(ctx); forcedOutboundTag != "" {
		ctx = session.SetForcedOutboundTagToContext(ctx, "")
		if h := d.ohm.GetHandler(forcedOutboundTag); h != nil {
			isPickRoute = 1
			newError("taking platform initialized detour [", forcedOutboundTag, "] for [", destination, "]").WriteToLog(session.ExportIDToError(ctx))
			handler = h
		} else {
			newError("non existing tag for platform initialized detour: ", forcedOutboundTag).AtError().WriteToLog(session.ExportIDToError(ctx))
			common.Close(link.Writer)
			common.Interrupt(link.Reader)
			return
		}
	} else if d.router != nil {
		if route, err := d.router.PickRoute(routingLink); err == nil {
			outTag := route.GetOutboundTag()
			if h := d.ohm.GetHandler(outTag); h != nil {
				isPickRoute = 2
				newError("taking detour [", outTag, "] for [", destination, "]").WriteToLog(session.ExportIDToError(ctx))
				handler = h
			} else {
				newError("non existing outTag: ", outTag).AtWarning().WriteToLog(session.ExportIDToError(ctx))
			}
		} else {
			newError("default route for ", destination).WriteToLog(session.ExportIDToError(ctx))
		}
	}

	if handler == nil {
		handler = d.ohm.GetHandler(inTag) // Default outbound handler tag should be as same as the inbound tag
	}
	// If there is no outbound with tag as same as the inbound tag
	if handler == nil {
		handler = d.ohm.GetDefaultHandler()
	}
	if handler == nil {
		newError("default outbound handler not exist").WriteToLog(session.ExportIDToError(ctx))
		common.Close(link.Writer)
		common.Interrupt(link.Reader)
		return
	}

	if accessMessage := log.AccessMessageFromContext(ctx); accessMessage != nil {
		if tag := handler.Tag(); tag != "" {
			// The separator tells how the handler was selected.
			switch {
			case inTag == "":
				accessMessage.Detour = tag
			case isPickRoute == 1:
				accessMessage.Detour = inTag + " ==> " + tag
			case isPickRoute == 2:
				accessMessage.Detour = inTag + " -> " + tag
			default:
				accessMessage.Detour = inTag + " >> " + tag
			}
		}
		log.Record(accessMessage)
	}

	handler.Dispatch(ctx, link)
}