405 lines
11 KiB
Go
405 lines
11 KiB
Go
package main
|
||
|
||
import (
|
||
"crypto/rand"
|
||
"crypto/rsa"
|
||
"crypto/tls"
|
||
"crypto/x509"
|
||
"crypto/x509/pkix"
|
||
"encoding/pem"
|
||
"flag"
|
||
"fmt"
|
||
"io"
|
||
"io/ioutil"
|
||
"log"
|
||
"math/big"
|
||
"net"
|
||
"net/http"
|
||
"net/url"
|
||
"os"
|
||
"path/filepath"
|
||
"time"
|
||
"github.com/miekg/dns"
|
||
)
|
||
|
||
// 添加全局变量
|
||
var (
|
||
proxyAddress string // 默认值,可以通过命令行参数修改
|
||
listenAddress string // 新增的地址代理参数
|
||
dnsPort string // 新增DNS服务器端口
|
||
)
|
||
|
||
func main() {
|
||
// 解析命令行参数
|
||
flag.StringVar(&listenAddress, "address", "127.0.0.1", "HTTP和HTTPS监听地址")
|
||
flag.StringVar(&proxyAddress, "address-proxy", "127.0.0.1:7897", "代理服务器地址")
|
||
flag.StringVar(&dnsPort, "dns-port", "53", "DNS服务器端口")
|
||
flag.Parse()
|
||
|
||
// 检查CA证书是否存在
|
||
sslDir := "SSL"
|
||
caFile := filepath.Join(sslDir, "ca.pem")
|
||
keyFile := filepath.Join(sslDir, "ca-key.pem")
|
||
if !fileExists(caFile) || !fileExists(keyFile) {
|
||
generateCACert(caFile, keyFile)
|
||
}
|
||
|
||
// 加载CA证书和私钥
|
||
caCert, caKey, err := loadCA(caFile, keyFile)
|
||
if err != nil {
|
||
log.Fatal("加载CA证书和私钥失败:", err)
|
||
}
|
||
|
||
// 设置HTTP服务器
|
||
go func() {
|
||
httpServer := &http.Server{
|
||
Addr: listenAddress + ":80",
|
||
Handler: http.HandlerFunc(httpHandler),
|
||
}
|
||
log.Printf("启动HTTP服务器,监听地址: %s:80\n", listenAddress)
|
||
if err := httpServer.ListenAndServe(); err != nil {
|
||
log.Fatal("HTTP服务器错误:", err)
|
||
}
|
||
}()
|
||
|
||
// 设置HTTPS服务器
|
||
httpsServer := &http.Server{
|
||
Addr: listenAddress + ":443",
|
||
TLSConfig: &tls.Config{
|
||
GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||
sni := hello.ServerName
|
||
fmt.Printf("接收到HTTPS请求,SNI: %s\n", sni)
|
||
return generateCert(caCert, caKey, sni)
|
||
},
|
||
},
|
||
Handler: http.HandlerFunc(httpsHandler),
|
||
}
|
||
|
||
log.Printf("启动HTTPS服务器,监听地址: %s:443, 代理地址: %s...\n", listenAddress, proxyAddress)
|
||
if err := httpsServer.ListenAndServeTLS("", ""); err != nil {
|
||
log.Fatal("HTTPS服务器错误:", err)
|
||
}
|
||
|
||
// 启动DNS服务器
|
||
go startDNSServer()
|
||
}
|
||
|
||
func httpHandler(w http.ResponseWriter, r *http.Request) {
|
||
// 实现HTTP代理逻辑
|
||
if r.Method == http.MethodConnect {
|
||
handleHTTPS(w, r)
|
||
} else {
|
||
handleHTTP(w, r)
|
||
}
|
||
}
|
||
|
||
func handleHTTP(w http.ResponseWriter, r *http.Request) {
|
||
// 用全局变量创建代理URL
|
||
proxyURL, err := url.Parse("http://" + proxyAddress)
|
||
if err != nil {
|
||
http.Error(w, "解析代理URL失败", http.StatusInternalServerError)
|
||
return
|
||
}
|
||
|
||
// 创建一个使用HTTP代理的客户端
|
||
proxyClient := &http.Client{
|
||
Transport: &http.Transport{
|
||
Proxy: http.ProxyURL(proxyURL),
|
||
},
|
||
}
|
||
|
||
// 创建一个新的请求
|
||
proxyReq, err := http.NewRequest(r.Method, r.URL.String(), r.Body)
|
||
if err != nil {
|
||
http.Error(w, "创建代理请求失败", http.StatusInternalServerError)
|
||
return
|
||
}
|
||
|
||
// 复制原始请求的头部
|
||
for name, values := range r.Header {
|
||
for _, value := range values {
|
||
proxyReq.Header.Add(name, value)
|
||
}
|
||
}
|
||
|
||
// 发送代理请求
|
||
resp, err := proxyClient.Do(proxyReq)
|
||
if err != nil {
|
||
http.Error(w, "代理请求失败", http.StatusBadGateway)
|
||
return
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
// 复制响应头部
|
||
for name, values := range resp.Header {
|
||
for _, value := range values {
|
||
w.Header().Add(name, value)
|
||
}
|
||
}
|
||
|
||
// 设置状态码
|
||
w.WriteHeader(resp.StatusCode)
|
||
|
||
// 复制响应体
|
||
io.Copy(w, resp.Body)
|
||
|
||
// 设置Connection头为close,确保客户端知道连接将被关闭
|
||
w.Header().Set("Connection", "close")
|
||
}
|
||
|
||
func handleHTTPS(w http.ResponseWriter, r *http.Request) {
|
||
// 建立到目标服务器的连接
|
||
destConn, err := net.Dial("tcp", r.Host)
|
||
if err != nil {
|
||
http.Error(w, err.Error(), http.StatusServiceUnavailable)
|
||
return
|
||
}
|
||
defer destConn.Close()
|
||
|
||
// 向客户端发送200 OK响应
|
||
w.WriteHeader(http.StatusOK)
|
||
|
||
// 获取底层的TCP连接
|
||
hijacker, ok := w.(http.Hijacker)
|
||
if !ok {
|
||
http.Error(w, "Hijacking not supported", http.StatusInternalServerError)
|
||
return
|
||
}
|
||
clientConn, _, err := hijacker.Hijack()
|
||
if err != nil {
|
||
http.Error(w, err.Error(), http.StatusServiceUnavailable)
|
||
return
|
||
}
|
||
defer clientConn.Close()
|
||
|
||
// 在客户端和目标服务器之间转发数据
|
||
go io.Copy(destConn, clientConn)
|
||
io.Copy(clientConn, destConn)
|
||
}
|
||
|
||
func fileExists(filename string) bool {
|
||
_, err := os.Stat(filename)
|
||
return !os.IsNotExist(err)
|
||
}
|
||
|
||
func generateCACert(caFile, keyFile string) {
|
||
// 生成CA私钥
|
||
caKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||
if err != nil {
|
||
log.Fatal("生成CA私钥失败:", err)
|
||
}
|
||
|
||
// 创建CA证书模板
|
||
caTemplate := x509.Certificate{
|
||
SerialNumber: big.NewInt(1),
|
||
Subject: pkix.Name{
|
||
CommonName: "自签名CA",
|
||
},
|
||
NotBefore: time.Now(),
|
||
NotAfter: time.Now().AddDate(10, 0, 0),
|
||
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature,
|
||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||
BasicConstraintsValid: true,
|
||
IsCA: true,
|
||
}
|
||
|
||
// 创建CA证书
|
||
caCertDER, err := x509.CreateCertificate(rand.Reader, &caTemplate, &caTemplate, &caKey.PublicKey, caKey)
|
||
if err != nil {
|
||
log.Fatal("创建CA证书失败:", err)
|
||
}
|
||
|
||
// 确保SSL目录存在
|
||
sslDir := filepath.Dir(caFile)
|
||
if err := os.MkdirAll(sslDir, 0755); err != nil {
|
||
log.Fatal("创建SSL目录失败:", err)
|
||
}
|
||
|
||
// 将CA证书写入文件
|
||
caCertPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: caCertDER})
|
||
if err := ioutil.WriteFile(caFile, caCertPEM, 0644); err != nil {
|
||
log.Fatal("写入CA证书失败:", err)
|
||
}
|
||
|
||
// 将CA私钥写入文件
|
||
caKeyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(caKey)})
|
||
if err := ioutil.WriteFile(keyFile, caKeyPEM, 0600); err != nil {
|
||
log.Fatal("写入CA私钥失败:", err)
|
||
}
|
||
|
||
log.Println("生成了新的CA证书")
|
||
}
|
||
|
||
func loadCA(caFile, keyFile string) (*x509.Certificate, *rsa.PrivateKey, error) {
|
||
// 读取CA证书
|
||
caCertPEM, err := ioutil.ReadFile(caFile)
|
||
if err != nil {
|
||
return nil, nil, fmt.Errorf("读取CA证书失败: %v", err)
|
||
}
|
||
caCertBlock, _ := pem.Decode(caCertPEM)
|
||
if caCertBlock == nil {
|
||
return nil, nil, fmt.Errorf("解码CA证书失败")
|
||
}
|
||
caCert, err := x509.ParseCertificate(caCertBlock.Bytes)
|
||
if err != nil {
|
||
return nil, nil, fmt.Errorf("解析CA证书失败: %v", err)
|
||
}
|
||
|
||
// 读取CA私钥
|
||
caKeyPEM, err := ioutil.ReadFile(keyFile)
|
||
if err != nil {
|
||
return nil, nil, fmt.Errorf("读取CA私钥失败: %v", err)
|
||
}
|
||
caKeyBlock, _ := pem.Decode(caKeyPEM)
|
||
if caKeyBlock == nil {
|
||
return nil, nil, fmt.Errorf("解码CA私钥失败")
|
||
}
|
||
caKey, err := x509.ParsePKCS1PrivateKey(caKeyBlock.Bytes)
|
||
if err != nil {
|
||
return nil, nil, fmt.Errorf("解析CA私钥失败: %v", err)
|
||
}
|
||
|
||
return caCert, caKey, nil
|
||
}
|
||
|
||
func generateCert(caCert *x509.Certificate, caKey *rsa.PrivateKey, serverName string) (*tls.Certificate, error) {
|
||
// 生成服务器私钥
|
||
serverKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("生成服务器私钥失败: %v", err)
|
||
}
|
||
|
||
// 创建服务器证书模板
|
||
serverTemplate := x509.Certificate{
|
||
SerialNumber: big.NewInt(2),
|
||
Subject: pkix.Name{
|
||
CommonName: serverName,
|
||
},
|
||
NotBefore: time.Now(),
|
||
NotAfter: time.Now().AddDate(1, 0, 0),
|
||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||
DNSNames: []string{serverName},
|
||
}
|
||
|
||
// 使用CA签发服务器证书
|
||
serverCertDER, err := x509.CreateCertificate(rand.Reader, &serverTemplate, caCert, &serverKey.PublicKey, caKey)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("创建服务器证书失败: %v", err)
|
||
}
|
||
|
||
// 将证书和私钥转换为PEM格式
|
||
serverCertPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: serverCertDER})
|
||
serverKeyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(serverKey)})
|
||
|
||
// 创建SSL文件夹和SNI子文件夹
|
||
sslDir := "SSL"
|
||
sniDir := filepath.Join(sslDir, serverName)
|
||
if err := os.MkdirAll(sniDir, 0755); err != nil {
|
||
return nil, fmt.Errorf("创建目录失败: %v", err)
|
||
}
|
||
|
||
// 保存证书和私钥到SNI子文件夹
|
||
certFile := filepath.Join(sniDir, serverName+".pem")
|
||
keyFile := filepath.Join(sniDir, serverName+"-key.pem")
|
||
if err := ioutil.WriteFile(certFile, serverCertPEM, 0644); err != nil {
|
||
return nil, fmt.Errorf("写入服务器证书失败: %v", err)
|
||
}
|
||
if err := ioutil.WriteFile(keyFile, serverKeyPEM, 0600); err != nil {
|
||
return nil, fmt.Errorf("写入服务器私钥失败: %v", err)
|
||
}
|
||
|
||
// 创tls.Certificate对象
|
||
cert, err := tls.X509KeyPair(serverCertPEM, serverKeyPEM)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("创建X509密钥对失败: %v", err)
|
||
}
|
||
|
||
return &cert, nil
|
||
}
|
||
|
||
func httpsHandler(w http.ResponseWriter, r *http.Request) {
|
||
// 确保URL使用HTTPS协议并包含正确的Host
|
||
r.URL.Scheme = "https"
|
||
if r.URL.Host == "" {
|
||
r.URL.Host = r.Host
|
||
}
|
||
|
||
// 创建正向代理客户端
|
||
proxyClient := &http.Client{
|
||
Transport: &http.Transport{
|
||
Proxy: http.ProxyURL(&url.URL{
|
||
Scheme: "http",
|
||
Host: proxyAddress,
|
||
}),
|
||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||
DisableKeepAlives: true, // 禁用Keep-Alive
|
||
},
|
||
}
|
||
|
||
// 创建一个新的请求
|
||
proxyReq, err := http.NewRequest(r.Method, r.URL.String(), r.Body)
|
||
if err != nil {
|
||
http.Error(w, "创建代理请求失败", http.StatusInternalServerError)
|
||
return
|
||
}
|
||
// 复制原始请求的头部
|
||
proxyReq.Header = r.Header.Clone()
|
||
|
||
// 确保请求体被关闭
|
||
if r.Body != nil {
|
||
defer r.Body.Close()
|
||
}
|
||
|
||
// 发送代理请求
|
||
resp, err := proxyClient.Do(proxyReq)
|
||
if err != nil {
|
||
http.Error(w, "代理请求失败: "+err.Error(), http.StatusBadGateway)
|
||
return
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
// 复制响应头部和状态码
|
||
for name, values := range resp.Header {
|
||
w.Header()[name] = values
|
||
}
|
||
w.WriteHeader(resp.StatusCode)
|
||
|
||
// 复制响应体
|
||
_, err = io.Copy(w, resp.Body)
|
||
if err != nil {
|
||
log.Printf("复制响应体时发生错误: %v", err)
|
||
}
|
||
|
||
// 设置Connection头为close,确保客户端知道连接将被关闭
|
||
w.Header().Set("Connection", "close")
|
||
}
|
||
|
||
func startDNSServer() {
|
||
dns.HandleFunc(".", handleDNSRequest)
|
||
server := &dns.Server{Addr: ":" + dnsPort, Net: "udp"}
|
||
log.Printf("启动DNS服务器,监听端口: %s\n", dnsPort)
|
||
err := server.ListenAndServe()
|
||
if err != nil {
|
||
log.Fatalf("启动DNS服务器失败: %v", err)
|
||
}
|
||
}
|
||
|
||
func handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||
m := new(dns.Msg)
|
||
m.SetReply(r)
|
||
m.Authoritative = true
|
||
|
||
for _, q := range m.Question {
|
||
switch q.Qtype {
|
||
case dns.TypeA:
|
||
rr, err := dns.NewRR(fmt.Sprintf("%s A %s", q.Name, listenAddress))
|
||
if err == nil {
|
||
m.Answer = append(m.Answer, rr)
|
||
}
|
||
}
|
||
}
|
||
|
||
w.WriteMsg(m)
|
||
}
|