mirror of
https://github.com/LLOneBot/LLOneBot.git
synced 2024-11-22 01:56:33 +00:00
fix
This commit is contained in:
parent
c37858e2f9
commit
2792fa4776
@ -1,93 +0,0 @@
|
||||
import { WebSocket, WebSocketServer } from 'ws'
|
||||
import urlParse from 'url'
|
||||
import { IncomingMessage } from 'node:http'
|
||||
import { log } from '../utils/log'
|
||||
import { getConfigUtil } from '../config'
|
||||
import { llonebotError } from '../data'
|
||||
|
||||
class WebsocketClientBase {
|
||||
private wsClient: WebSocket | undefined
|
||||
|
||||
constructor() { }
|
||||
|
||||
send(msg: string) {
|
||||
if (this.wsClient && this.wsClient.readyState == WebSocket.OPEN) {
|
||||
this.wsClient.send(msg)
|
||||
}
|
||||
}
|
||||
|
||||
onMessage(msg: string) { }
|
||||
}
|
||||
|
||||
export class WebsocketServerBase {
|
||||
private ws: WebSocketServer | null = null
|
||||
|
||||
constructor() {
|
||||
console.log(`llonebot websocket service started`)
|
||||
}
|
||||
|
||||
start(port: number) {
|
||||
try {
|
||||
this.ws = new WebSocketServer({ port, maxPayload: 1024 * 1024 * 1024 })
|
||||
llonebotError.wsServerError = ''
|
||||
} catch (e: any) {
|
||||
llonebotError.wsServerError = '正向ws服务启动失败, ' + e.toString()
|
||||
}
|
||||
this.ws?.on('connection', (wsClient, req) => {
|
||||
const url = req.url?.split('?').shift()
|
||||
this.authorize(wsClient, req)
|
||||
this.onConnect(wsClient, url!, req)
|
||||
wsClient.on('message', async (msg) => {
|
||||
this.onMessage(wsClient, url!, msg.toString())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
stop() {
|
||||
llonebotError.wsServerError = ''
|
||||
this.ws?.close((err) => {
|
||||
log('ws server close failed!', err)
|
||||
})
|
||||
this.ws = null
|
||||
}
|
||||
|
||||
restart(port: number) {
|
||||
this.stop()
|
||||
this.start(port)
|
||||
}
|
||||
|
||||
authorize(wsClient: WebSocket, req) {
|
||||
let token = getConfigUtil().getConfig().token
|
||||
const url = req.url.split('?').shift()
|
||||
log('ws connect', url)
|
||||
let clientToken: string = ''
|
||||
const authHeader = req.headers['authorization']
|
||||
if (authHeader) {
|
||||
clientToken = authHeader.split('Bearer ').pop()
|
||||
log('receive ws header token', clientToken)
|
||||
} else {
|
||||
const parsedUrl = urlParse.parse(req.url, true)
|
||||
const urlToken = parsedUrl.query.access_token
|
||||
if (urlToken) {
|
||||
if (Array.isArray(urlToken)) {
|
||||
clientToken = urlToken[0]
|
||||
} else {
|
||||
clientToken = urlToken
|
||||
}
|
||||
log('receive ws url token', clientToken)
|
||||
}
|
||||
}
|
||||
if (token && clientToken != token) {
|
||||
this.authorizeFailed(wsClient)
|
||||
return wsClient.close()
|
||||
}
|
||||
}
|
||||
|
||||
authorizeFailed(wsClient: WebSocket) { }
|
||||
|
||||
onConnect(wsClient: WebSocket, url: string, req: IncomingMessage) { }
|
||||
|
||||
onMessage(wsClient: WebSocket, url: string, msg: string) { }
|
||||
|
||||
sendHeart() { }
|
||||
}
|
@ -1,9 +1,9 @@
|
||||
import fs from 'node:fs'
|
||||
import fsPromise from 'node:fs/promises'
|
||||
import path from 'node:path'
|
||||
import { log, TEMP_DIR } from './index'
|
||||
import * as fileType from 'file-type'
|
||||
import { TEMP_DIR } from './index'
|
||||
import { randomUUID, createHash } from 'node:crypto'
|
||||
import { fileURLToPath } from 'node:url'
|
||||
|
||||
export function isGIF(path: string) {
|
||||
const buffer = Buffer.alloc(4)
|
||||
@ -32,31 +32,6 @@ export function checkFileReceived(path: string, timeout: number = 3000): Promise
|
||||
})
|
||||
}
|
||||
|
||||
export async function file2base64(path: string) {
|
||||
let result = {
|
||||
err: '',
|
||||
data: '',
|
||||
}
|
||||
try {
|
||||
// 读取文件内容
|
||||
// if (!fs.existsSync(path)){
|
||||
// path = path.replace("\\Ori\\", "\\Thumb\\");
|
||||
// }
|
||||
try {
|
||||
await checkFileReceived(path, 5000)
|
||||
} catch (e: any) {
|
||||
result.err = e.toString()
|
||||
return result
|
||||
}
|
||||
const data = await fsPromise.readFile(path)
|
||||
// 转换为Base64编码
|
||||
result.data = data.toString('base64')
|
||||
} catch (err: any) {
|
||||
result.err = err.toString()
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
export function calculateFileMD5(filePath: string): Promise<string> {
|
||||
return new Promise((resolve, reject) => {
|
||||
// 创建一个流式读取器
|
||||
@ -109,112 +84,118 @@ export async function httpDownload(options: string | HttpDownloadOptions): Promi
|
||||
return Buffer.from(await fetchRes.arrayBuffer())
|
||||
}
|
||||
|
||||
export enum FileUriType {
|
||||
Unknown = 0,
|
||||
FileURL = 1,
|
||||
RemoteURL = 2,
|
||||
OneBotBase64 = 3,
|
||||
DataURL = 4,
|
||||
Path = 5
|
||||
}
|
||||
|
||||
export function checkUriType(uri: string): { type: FileUriType } {
|
||||
if (uri.startsWith('base64://')) {
|
||||
return { type: FileUriType.OneBotBase64 }
|
||||
}
|
||||
if (uri.startsWith('data:')) {
|
||||
return { type: FileUriType.DataURL }
|
||||
}
|
||||
if (uri.startsWith('http://') || uri.startsWith('https://')) {
|
||||
return { type: FileUriType.RemoteURL }
|
||||
}
|
||||
if (uri.startsWith('file://')) {
|
||||
return { type: FileUriType.FileURL }
|
||||
}
|
||||
try {
|
||||
if (fs.existsSync(uri)) return { type: FileUriType.Path }
|
||||
} catch { }
|
||||
return { type: FileUriType.Unknown }
|
||||
}
|
||||
|
||||
interface FetchFileRes {
|
||||
data: Buffer
|
||||
url: string
|
||||
}
|
||||
|
||||
async function fetchFile(url: string): Promise<FetchFileRes> {
|
||||
const headers: Record<string, string> = {
|
||||
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/94.0.4606.71 Safari/537.36',
|
||||
'Host': new URL(url).hostname
|
||||
}
|
||||
const raw = await fetch(url, { headers }).catch((err) => {
|
||||
if (err.cause) {
|
||||
throw err.cause
|
||||
}
|
||||
throw err
|
||||
})
|
||||
if (!raw.ok) throw new Error(`statusText: ${raw.statusText}`)
|
||||
return {
|
||||
data: Buffer.from(await raw.arrayBuffer()),
|
||||
url: raw.url
|
||||
}
|
||||
}
|
||||
|
||||
type Uri2LocalRes = {
|
||||
success: boolean
|
||||
errMsg: string
|
||||
fileName: string
|
||||
ext: string
|
||||
path: string
|
||||
isLocal: boolean
|
||||
}
|
||||
|
||||
export async function uri2local(uri: string, fileName: string | null = null): Promise<Uri2LocalRes> {
|
||||
let res = {
|
||||
success: false,
|
||||
errMsg: '',
|
||||
fileName: '',
|
||||
ext: '',
|
||||
path: '',
|
||||
isLocal: false,
|
||||
}
|
||||
if (!fileName) {
|
||||
fileName = randomUUID()
|
||||
}
|
||||
let filePath = path.join(TEMP_DIR, fileName)
|
||||
let url: URL | null = null
|
||||
try {
|
||||
url = new URL(uri)
|
||||
} catch (e: any) {
|
||||
res.errMsg = `uri ${uri} 解析失败,` + e.toString() + ` 可能${uri}不存在`
|
||||
return res
|
||||
export async function uri2local(uri: string, filename?: string): Promise<Uri2LocalRes> {
|
||||
const { type } = checkUriType(uri)
|
||||
|
||||
if (type === FileUriType.FileURL) {
|
||||
const filePath = fileURLToPath(uri)
|
||||
const fileName = path.basename(filePath)
|
||||
return { success: true, errMsg: '', fileName, path: filePath, isLocal: true }
|
||||
}
|
||||
|
||||
// log("uri protocol", url.protocol, uri);
|
||||
if (url.protocol == 'base64:') {
|
||||
// base64转成文件
|
||||
let base64Data = uri.split('base64://')[1]
|
||||
if (type === FileUriType.Path) {
|
||||
const fileName = path.basename(uri)
|
||||
return { success: true, errMsg: '', fileName, path: uri, isLocal: true }
|
||||
}
|
||||
|
||||
if (type === FileUriType.RemoteURL) {
|
||||
try {
|
||||
const buffer = Buffer.from(base64Data, 'base64')
|
||||
await fsPromise.writeFile(filePath, buffer)
|
||||
} catch (e: any) {
|
||||
res.errMsg = `base64文件下载失败,` + e.toString()
|
||||
return res
|
||||
}
|
||||
} else if (url.protocol == 'http:' || url.protocol == 'https:') {
|
||||
// 下载文件
|
||||
let buffer: Buffer | null = null
|
||||
try {
|
||||
buffer = await httpDownload(uri)
|
||||
} catch (e: any) {
|
||||
res.errMsg = `${url}下载失败,` + e.toString()
|
||||
return res
|
||||
}
|
||||
try {
|
||||
const pathInfo = path.parse(decodeURIComponent(url.pathname))
|
||||
if (pathInfo.name) {
|
||||
fileName = pathInfo.name
|
||||
if (pathInfo.ext) {
|
||||
fileName += pathInfo.ext
|
||||
// res.ext = pathInfo.ext
|
||||
}
|
||||
}
|
||||
fileName = fileName.replace(/[/\\:*?"<>|]/g, '_')
|
||||
res.fileName = fileName
|
||||
filePath = path.join(TEMP_DIR, randomUUID() + fileName)
|
||||
await fsPromise.writeFile(filePath, buffer)
|
||||
} catch (e: any) {
|
||||
res.errMsg = `${url}下载失败,` + e.toString()
|
||||
return res
|
||||
}
|
||||
} else {
|
||||
let pathname: string
|
||||
if (url.protocol === 'file:') {
|
||||
// await fs.copyFile(url.pathname, filePath);
|
||||
pathname = decodeURIComponent(url.pathname)
|
||||
if (process.platform === 'win32') {
|
||||
filePath = pathname.slice(1)
|
||||
const res = await fetchFile(uri)
|
||||
const match = res.url.match(/.+\/([^/?]*)(?=\?)?/)
|
||||
if (match?.[1]) {
|
||||
filename ??= match[1].replace(/[/\\:*?"<>|]/g, '_')
|
||||
} else {
|
||||
filePath = pathname
|
||||
filename ??= randomUUID()
|
||||
}
|
||||
const filePath = path.join(TEMP_DIR, filename)
|
||||
await fsPromise.writeFile(filePath, res.data)
|
||||
return { success: true, errMsg: '', fileName: filename, path: filePath, isLocal: false }
|
||||
} catch (e: any) {
|
||||
const errMsg = `${uri}下载失败,` + e.toString()
|
||||
return { success: false, errMsg, fileName: '', path: '', isLocal: false }
|
||||
}
|
||||
}
|
||||
|
||||
res.isLocal = true
|
||||
if (type === FileUriType.OneBotBase64) {
|
||||
filename ??= randomUUID()
|
||||
const filePath = path.join(TEMP_DIR, filename)
|
||||
const base64 = uri.replace(/^base64:\/\//, '')
|
||||
await fsPromise.writeFile(filePath, base64, 'base64')
|
||||
return { success: true, errMsg: '', fileName: filename, path: filePath, isLocal: false }
|
||||
}
|
||||
// else{
|
||||
// res.errMsg = `不支持的file协议,` + url.protocol
|
||||
// return res
|
||||
// }
|
||||
// if (isGIF(filePath) && !res.isLocal) {
|
||||
// await fs.rename(filePath, filePath + ".gif");
|
||||
// filePath += ".gif";
|
||||
// }
|
||||
if (!res.isLocal && !res.ext) {
|
||||
try {
|
||||
const ext = (await fileType.fileTypeFromFile(filePath))?.ext
|
||||
if (ext) {
|
||||
log('获取文件类型', ext, filePath)
|
||||
await fsPromise.rename(filePath, filePath + `.${ext}`)
|
||||
filePath += `.${ext}`
|
||||
res.fileName += `.${ext}`
|
||||
res.ext = ext
|
||||
}
|
||||
} catch (e) {
|
||||
// log("获取文件类型失败", filePath,e.stack)
|
||||
|
||||
if (type === FileUriType.DataURL) {
|
||||
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Basics_of_HTTP/MIME_types/Common_types
|
||||
const capture = /^data:([\w/.+-]+);base64,(.*)$/.exec(uri)
|
||||
if (capture) {
|
||||
filename ??= randomUUID()
|
||||
const [, _type, base64] = capture
|
||||
const filePath = path.join(TEMP_DIR, filename)
|
||||
await fsPromise.writeFile(filePath, base64, 'base64')
|
||||
return { success: true, errMsg: '', fileName: filename, path: filePath, isLocal: false }
|
||||
}
|
||||
}
|
||||
res.success = true
|
||||
res.path = filePath
|
||||
return res
|
||||
|
||||
return { success: false, errMsg: '未知文件类型', fileName: '', path: '', isLocal: false }
|
||||
}
|
||||
|
||||
export async function copyFolder(sourcePath: string, destPath: string) {
|
||||
|
@ -104,6 +104,10 @@ export class ReverseWebsocket {
|
||||
|
||||
this.websocket.on('error', log)
|
||||
|
||||
this.websocket.on('ping',()=>{
|
||||
this.websocket?.pong()
|
||||
})
|
||||
|
||||
const wsClientInterval = setInterval(() => {
|
||||
postWsEvent(new OB11HeartbeatEvent(selfInfo.online!, true, heartInterval!))
|
||||
}, heartInterval) // 心跳包
|
||||
|
@ -1,70 +1,131 @@
|
||||
import { WebSocket } from 'ws'
|
||||
import BaseAction from '../../action/BaseAction'
|
||||
import { WebSocket, WebSocketServer } from 'ws'
|
||||
import { actionMap } from '../../action'
|
||||
import { OB11Response } from '../../action/OB11Response'
|
||||
import { postWsEvent, registerWsEventSender, unregisterWsEventSender } from '../post-ob11-event'
|
||||
import { ActionName } from '../../action/types'
|
||||
import BaseAction from '../../action/BaseAction'
|
||||
import { LifeCycleSubType, OB11LifeCycleEvent } from '../../event/meta/OB11LifeCycleEvent'
|
||||
import { OB11HeartbeatEvent } from '../../event/meta/OB11HeartbeatEvent'
|
||||
import { WebsocketServerBase } from '../../../common/server/websocket'
|
||||
import { IncomingMessage } from 'node:http'
|
||||
import { wsReply } from './reply'
|
||||
import { getSelfInfo } from '../../../common/data'
|
||||
import { log } from '../../../common/utils/log'
|
||||
import { getConfigUtil } from '../../../common/config'
|
||||
import { getSelfInfo } from '@/common/data'
|
||||
import { log } from '@/common/utils/log'
|
||||
import { getConfigUtil } from '@/common/config'
|
||||
import { llonebotError } from '@/common/data'
|
||||
|
||||
class OB11WebsocketServer extends WebsocketServerBase {
|
||||
authorizeFailed(wsClient: WebSocket) {
|
||||
wsClient.send(JSON.stringify(OB11Response.res(null, 'failed', 1403, 'token验证失败')))
|
||||
export class OB11WebsocketServer {
|
||||
private ws?: WebSocketServer
|
||||
|
||||
constructor() {
|
||||
log(`llonebot websocket service started`)
|
||||
}
|
||||
|
||||
async handleAction(wsClient: WebSocket, actionName: string, params: any, echo?: any) {
|
||||
start(port: number) {
|
||||
try {
|
||||
this.ws = new WebSocketServer({ port, maxPayload: 1024 * 1024 * 1024 })
|
||||
llonebotError.wsServerError = ''
|
||||
} catch (e: any) {
|
||||
llonebotError.wsServerError = '正向 WebSocket 服务启动失败, ' + e.toString()
|
||||
return
|
||||
}
|
||||
this.ws?.on('connection', (socket, req) => {
|
||||
const url = req.url?.split('?').shift()
|
||||
this.authorize(socket, req)
|
||||
this.onConnect(socket, url!)
|
||||
})
|
||||
}
|
||||
|
||||
stop() {
|
||||
llonebotError.wsServerError = ''
|
||||
this.ws?.close(err => {
|
||||
log('ws server close failed!', err)
|
||||
})
|
||||
this.ws = undefined
|
||||
}
|
||||
|
||||
restart(port: number) {
|
||||
this.stop()
|
||||
this.start(port)
|
||||
}
|
||||
|
||||
private authorize(socket: WebSocket, req: IncomingMessage) {
|
||||
const { token } = getConfigUtil().getConfig()
|
||||
const url = req.url?.split('?').shift()
|
||||
log('ws connect', url)
|
||||
let clientToken = ''
|
||||
const authHeader = req.headers['authorization']
|
||||
if (authHeader) {
|
||||
clientToken = authHeader.split('Bearer ').pop()!
|
||||
log('receive ws header token', clientToken)
|
||||
} else {
|
||||
const { searchParams } = new URL(`http://localhost${req.url}`)
|
||||
const urlToken = searchParams.get('access_token')
|
||||
if (urlToken) {
|
||||
if (Array.isArray(urlToken)) {
|
||||
clientToken = urlToken[0]
|
||||
} else {
|
||||
clientToken = urlToken
|
||||
}
|
||||
log('receive ws url token', clientToken)
|
||||
}
|
||||
}
|
||||
if (token && clientToken !== token) {
|
||||
this.authorizeFailed(socket)
|
||||
return socket.close()
|
||||
}
|
||||
}
|
||||
|
||||
private authorizeFailed(socket: WebSocket) {
|
||||
socket.send(JSON.stringify(OB11Response.res(null, 'failed', 1403, 'token验证失败')))
|
||||
}
|
||||
|
||||
private async handleAction(socket: WebSocket, actionName: string, params: any, echo?: any) {
|
||||
const action: BaseAction<any, any> = actionMap.get(actionName)!
|
||||
if (!action) {
|
||||
return wsReply(wsClient, OB11Response.error('不支持的api ' + actionName, 1404, echo))
|
||||
return wsReply(socket, OB11Response.error('不支持的api ' + actionName, 1404, echo))
|
||||
}
|
||||
try {
|
||||
let handleResult = await action.websocketHandle(params, echo)
|
||||
const handleResult = await action.websocketHandle(params, echo)
|
||||
handleResult.echo = echo
|
||||
wsReply(wsClient, handleResult)
|
||||
wsReply(socket, handleResult)
|
||||
} catch (e: any) {
|
||||
wsReply(wsClient, OB11Response.error(`api处理出错:${e.stack}`, 1200, echo))
|
||||
wsReply(socket, OB11Response.error(`api处理出错:${e.stack}`, 1200, echo))
|
||||
}
|
||||
}
|
||||
|
||||
onConnect(wsClient: WebSocket, url: string, req: IncomingMessage) {
|
||||
if (url == '/api' || url == '/api/' || url == '/') {
|
||||
wsClient.on('message', async (msg) => {
|
||||
private onConnect(socket: WebSocket, url: string) {
|
||||
if (['/api', '/api/', '/'].includes(url)) {
|
||||
socket.on('message', async (msg) => {
|
||||
let receiveData: { action: ActionName | null; params: any; echo?: any } = { action: null, params: {} }
|
||||
let echo = null
|
||||
let echo: any
|
||||
try {
|
||||
receiveData = JSON.parse(msg.toString())
|
||||
echo = receiveData.echo
|
||||
log('收到正向Websocket消息', receiveData)
|
||||
} catch (e) {
|
||||
return wsReply(wsClient, OB11Response.error('json解析失败,请检查数据格式', 1400, echo))
|
||||
return wsReply(socket, OB11Response.error('json解析失败,请检查数据格式', 1400, echo))
|
||||
}
|
||||
this.handleAction(wsClient, receiveData.action!, receiveData.params, receiveData.echo).then()
|
||||
this.handleAction(socket, receiveData.action!, receiveData.params, receiveData.echo)
|
||||
})
|
||||
}
|
||||
if (url == '/event' || url == '/event/' || url == '/') {
|
||||
registerWsEventSender(wsClient)
|
||||
if (['/event', '/event/', '/'].includes(url)) {
|
||||
registerWsEventSender(socket)
|
||||
|
||||
log('event上报ws客户端已连接')
|
||||
|
||||
try {
|
||||
wsReply(wsClient, new OB11LifeCycleEvent(LifeCycleSubType.CONNECT))
|
||||
wsReply(socket, new OB11LifeCycleEvent(LifeCycleSubType.CONNECT))
|
||||
} catch (e) {
|
||||
log('发送生命周期失败', e)
|
||||
}
|
||||
const { heartInterval } = getConfigUtil().getConfig()
|
||||
const wsClientInterval = setInterval(() => {
|
||||
const intervalId = setInterval(() => {
|
||||
postWsEvent(new OB11HeartbeatEvent(getSelfInfo().online!, true, heartInterval!))
|
||||
}, heartInterval) // 心跳包
|
||||
wsClient.on('close', () => {
|
||||
socket.on('close', () => {
|
||||
log('event上报ws客户端已断开')
|
||||
clearInterval(wsClientInterval)
|
||||
unregisterWsEventSender(wsClient)
|
||||
clearInterval(intervalId)
|
||||
unregisterWsEventSender(socket)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user