From 2792fa4776e7e438d195e46b4e446b2f3a2519fa Mon Sep 17 00:00:00 2001 From: idranme Date: Wed, 21 Aug 2024 00:14:15 +0800 Subject: [PATCH] fix --- src/common/server/websocket.ts | 93 --------- src/common/utils/file.ts | 211 ++++++++++----------- src/onebot11/server/ws/ReverseWebsocket.ts | 4 + src/onebot11/server/ws/WebsocketServer.ts | 115 ++++++++--- 4 files changed, 188 insertions(+), 235 deletions(-) delete mode 100644 src/common/server/websocket.ts diff --git a/src/common/server/websocket.ts b/src/common/server/websocket.ts deleted file mode 100644 index dc44980..0000000 --- a/src/common/server/websocket.ts +++ /dev/null @@ -1,93 +0,0 @@ -import { WebSocket, WebSocketServer } from 'ws' -import urlParse from 'url' -import { IncomingMessage } from 'node:http' -import { log } from '../utils/log' -import { getConfigUtil } from '../config' -import { llonebotError } from '../data' - -class WebsocketClientBase { - private wsClient: WebSocket | undefined - - constructor() { } - - send(msg: string) { - if (this.wsClient && this.wsClient.readyState == WebSocket.OPEN) { - this.wsClient.send(msg) - } - } - - onMessage(msg: string) { } -} - -export class WebsocketServerBase { - private ws: WebSocketServer | null = null - - constructor() { - console.log(`llonebot websocket service started`) - } - - start(port: number) { - try { - this.ws = new WebSocketServer({ port, maxPayload: 1024 * 1024 * 1024 }) - llonebotError.wsServerError = '' - } catch (e: any) { - llonebotError.wsServerError = '正向ws服务启动失败, ' + e.toString() - } - this.ws?.on('connection', (wsClient, req) => { - const url = req.url?.split('?').shift() - this.authorize(wsClient, req) - this.onConnect(wsClient, url!, req) - wsClient.on('message', async (msg) => { - this.onMessage(wsClient, url!, msg.toString()) - }) - }) - } - - stop() { - llonebotError.wsServerError = '' - this.ws?.close((err) => { - log('ws server close failed!', err) - }) - this.ws = null - } - - restart(port: number) { - this.stop() - this.start(port) - } - - authorize(wsClient: WebSocket, req) { - let token = getConfigUtil().getConfig().token - const url = req.url.split('?').shift() - log('ws connect', url) - let clientToken: string = '' - const authHeader = req.headers['authorization'] - if (authHeader) { - clientToken = authHeader.split('Bearer ').pop() - log('receive ws header token', clientToken) - } else { - const parsedUrl = urlParse.parse(req.url, true) - const urlToken = parsedUrl.query.access_token - if (urlToken) { - if (Array.isArray(urlToken)) { - clientToken = urlToken[0] - } else { - clientToken = urlToken - } - log('receive ws url token', clientToken) - } - } - if (token && clientToken != token) { - this.authorizeFailed(wsClient) - return wsClient.close() - } - } - - authorizeFailed(wsClient: WebSocket) { } - - onConnect(wsClient: WebSocket, url: string, req: IncomingMessage) { } - - onMessage(wsClient: WebSocket, url: string, msg: string) { } - - sendHeart() { } -} diff --git a/src/common/utils/file.ts b/src/common/utils/file.ts index 7c349e9..752ebfa 100644 --- a/src/common/utils/file.ts +++ b/src/common/utils/file.ts @@ -1,9 +1,9 @@ import fs from 'node:fs' import fsPromise from 'node:fs/promises' import path from 'node:path' -import { log, TEMP_DIR } from './index' -import * as fileType from 'file-type' +import { TEMP_DIR } from './index' import { randomUUID, createHash } from 'node:crypto' +import { fileURLToPath } from 'node:url' export function isGIF(path: string) { const buffer = Buffer.alloc(4) @@ -32,31 +32,6 @@ export function checkFileReceived(path: string, timeout: number = 3000): Promise }) } -export async function file2base64(path: string) { - let result = { - err: '', - data: '', - } - try { - // 读取文件内容 - // if (!fs.existsSync(path)){ - // path = path.replace("\\Ori\\", "\\Thumb\\"); - // } - try { - await checkFileReceived(path, 5000) - } catch (e: any) { - result.err = e.toString() - return result - } - const data = await fsPromise.readFile(path) - // 转换为Base64编码 - result.data = data.toString('base64') - } catch (err: any) { - result.err = err.toString() - } - return result -} - export function calculateFileMD5(filePath: string): Promise { return new Promise((resolve, reject) => { // 创建一个流式读取器 @@ -109,112 +84,118 @@ export async function httpDownload(options: string | HttpDownloadOptions): Promi return Buffer.from(await fetchRes.arrayBuffer()) } +export enum FileUriType { + Unknown = 0, + FileURL = 1, + RemoteURL = 2, + OneBotBase64 = 3, + DataURL = 4, + Path = 5 +} + +export function checkUriType(uri: string): { type: FileUriType } { + if (uri.startsWith('base64://')) { + return { type: FileUriType.OneBotBase64 } + } + if (uri.startsWith('data:')) { + return { type: FileUriType.DataURL } + } + if (uri.startsWith('http://') || uri.startsWith('https://')) { + return { type: FileUriType.RemoteURL } + } + if (uri.startsWith('file://')) { + return { type: FileUriType.FileURL } + } + try { + if (fs.existsSync(uri)) return { type: FileUriType.Path } + } catch { } + return { type: FileUriType.Unknown } +} + +interface FetchFileRes { + data: Buffer + url: string +} + +async function fetchFile(url: string): Promise { + const headers: Record = { + 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/94.0.4606.71 Safari/537.36', + 'Host': new URL(url).hostname + } + const raw = await fetch(url, { headers }).catch((err) => { + if (err.cause) { + throw err.cause + } + throw err + }) + if (!raw.ok) throw new Error(`statusText: ${raw.statusText}`) + return { + data: Buffer.from(await raw.arrayBuffer()), + url: raw.url + } +} + type Uri2LocalRes = { success: boolean errMsg: string fileName: string - ext: string path: string isLocal: boolean } -export async function uri2local(uri: string, fileName: string | null = null): Promise { - let res = { - success: false, - errMsg: '', - fileName: '', - ext: '', - path: '', - isLocal: false, - } - if (!fileName) { - fileName = randomUUID() - } - let filePath = path.join(TEMP_DIR, fileName) - let url: URL | null = null - try { - url = new URL(uri) - } catch (e: any) { - res.errMsg = `uri ${uri} 解析失败,` + e.toString() + ` 可能${uri}不存在` - return res +export async function uri2local(uri: string, filename?: string): Promise { + const { type } = checkUriType(uri) + + if (type === FileUriType.FileURL) { + const filePath = fileURLToPath(uri) + const fileName = path.basename(filePath) + return { success: true, errMsg: '', fileName, path: filePath, isLocal: true } } - // log("uri protocol", url.protocol, uri); - if (url.protocol == 'base64:') { - // base64转成文件 - let base64Data = uri.split('base64://')[1] + if (type === FileUriType.Path) { + const fileName = path.basename(uri) + return { success: true, errMsg: '', fileName, path: uri, isLocal: true } + } + + if (type === FileUriType.RemoteURL) { try { - const buffer = Buffer.from(base64Data, 'base64') - await fsPromise.writeFile(filePath, buffer) - } catch (e: any) { - res.errMsg = `base64文件下载失败,` + e.toString() - return res - } - } else if (url.protocol == 'http:' || url.protocol == 'https:') { - // 下载文件 - let buffer: Buffer | null = null - try { - buffer = await httpDownload(uri) - } catch (e: any) { - res.errMsg = `${url}下载失败,` + e.toString() - return res - } - try { - const pathInfo = path.parse(decodeURIComponent(url.pathname)) - if (pathInfo.name) { - fileName = pathInfo.name - if (pathInfo.ext) { - fileName += pathInfo.ext - // res.ext = pathInfo.ext - } - } - fileName = fileName.replace(/[/\\:*?"<>|]/g, '_') - res.fileName = fileName - filePath = path.join(TEMP_DIR, randomUUID() + fileName) - await fsPromise.writeFile(filePath, buffer) - } catch (e: any) { - res.errMsg = `${url}下载失败,` + e.toString() - return res - } - } else { - let pathname: string - if (url.protocol === 'file:') { - // await fs.copyFile(url.pathname, filePath); - pathname = decodeURIComponent(url.pathname) - if (process.platform === 'win32') { - filePath = pathname.slice(1) + const res = await fetchFile(uri) + const match = res.url.match(/.+\/([^/?]*)(?=\?)?/) + if (match?.[1]) { + filename ??= match[1].replace(/[/\\:*?"<>|]/g, '_') } else { - filePath = pathname + filename ??= randomUUID() } + const filePath = path.join(TEMP_DIR, filename) + await fsPromise.writeFile(filePath, res.data) + return { success: true, errMsg: '', fileName: filename, path: filePath, isLocal: false } + } catch (e: any) { + const errMsg = `${uri}下载失败,` + e.toString() + return { success: false, errMsg, fileName: '', path: '', isLocal: false } } + } - res.isLocal = true + if (type === FileUriType.OneBotBase64) { + filename ??= randomUUID() + const filePath = path.join(TEMP_DIR, filename) + const base64 = uri.replace(/^base64:\/\//, '') + await fsPromise.writeFile(filePath, base64, 'base64') + return { success: true, errMsg: '', fileName: filename, path: filePath, isLocal: false } } - // else{ - // res.errMsg = `不支持的file协议,` + url.protocol - // return res - // } - // if (isGIF(filePath) && !res.isLocal) { - // await fs.rename(filePath, filePath + ".gif"); - // filePath += ".gif"; - // } - if (!res.isLocal && !res.ext) { - try { - const ext = (await fileType.fileTypeFromFile(filePath))?.ext - if (ext) { - log('获取文件类型', ext, filePath) - await fsPromise.rename(filePath, filePath + `.${ext}`) - filePath += `.${ext}` - res.fileName += `.${ext}` - res.ext = ext - } - } catch (e) { - // log("获取文件类型失败", filePath,e.stack) + + if (type === FileUriType.DataURL) { + // https://developer.mozilla.org/en-US/docs/Web/HTTP/Basics_of_HTTP/MIME_types/Common_types + const capture = /^data:([\w/.+-]+);base64,(.*)$/.exec(uri) + if (capture) { + filename ??= randomUUID() + const [, _type, base64] = capture + const filePath = path.join(TEMP_DIR, filename) + await fsPromise.writeFile(filePath, base64, 'base64') + return { success: true, errMsg: '', fileName: filename, path: filePath, isLocal: false } } } - res.success = true - res.path = filePath - return res + + return { success: false, errMsg: '未知文件类型', fileName: '', path: '', isLocal: false } } export async function copyFolder(sourcePath: string, destPath: string) { diff --git a/src/onebot11/server/ws/ReverseWebsocket.ts b/src/onebot11/server/ws/ReverseWebsocket.ts index c7c0462..8272bf9 100644 --- a/src/onebot11/server/ws/ReverseWebsocket.ts +++ b/src/onebot11/server/ws/ReverseWebsocket.ts @@ -104,6 +104,10 @@ export class ReverseWebsocket { this.websocket.on('error', log) + this.websocket.on('ping',()=>{ + this.websocket?.pong() + }) + const wsClientInterval = setInterval(() => { postWsEvent(new OB11HeartbeatEvent(selfInfo.online!, true, heartInterval!)) }, heartInterval) // 心跳包 diff --git a/src/onebot11/server/ws/WebsocketServer.ts b/src/onebot11/server/ws/WebsocketServer.ts index 4cff47e..ee6560f 100644 --- a/src/onebot11/server/ws/WebsocketServer.ts +++ b/src/onebot11/server/ws/WebsocketServer.ts @@ -1,70 +1,131 @@ -import { WebSocket } from 'ws' +import BaseAction from '../../action/BaseAction' +import { WebSocket, WebSocketServer } from 'ws' import { actionMap } from '../../action' import { OB11Response } from '../../action/OB11Response' import { postWsEvent, registerWsEventSender, unregisterWsEventSender } from '../post-ob11-event' import { ActionName } from '../../action/types' -import BaseAction from '../../action/BaseAction' import { LifeCycleSubType, OB11LifeCycleEvent } from '../../event/meta/OB11LifeCycleEvent' import { OB11HeartbeatEvent } from '../../event/meta/OB11HeartbeatEvent' -import { WebsocketServerBase } from '../../../common/server/websocket' import { IncomingMessage } from 'node:http' import { wsReply } from './reply' -import { getSelfInfo } from '../../../common/data' -import { log } from '../../../common/utils/log' -import { getConfigUtil } from '../../../common/config' +import { getSelfInfo } from '@/common/data' +import { log } from '@/common/utils/log' +import { getConfigUtil } from '@/common/config' +import { llonebotError } from '@/common/data' -class OB11WebsocketServer extends WebsocketServerBase { - authorizeFailed(wsClient: WebSocket) { - wsClient.send(JSON.stringify(OB11Response.res(null, 'failed', 1403, 'token验证失败'))) +export class OB11WebsocketServer { + private ws?: WebSocketServer + + constructor() { + log(`llonebot websocket service started`) } - async handleAction(wsClient: WebSocket, actionName: string, params: any, echo?: any) { + start(port: number) { + try { + this.ws = new WebSocketServer({ port, maxPayload: 1024 * 1024 * 1024 }) + llonebotError.wsServerError = '' + } catch (e: any) { + llonebotError.wsServerError = '正向 WebSocket 服务启动失败, ' + e.toString() + return + } + this.ws?.on('connection', (socket, req) => { + const url = req.url?.split('?').shift() + this.authorize(socket, req) + this.onConnect(socket, url!) + }) + } + + stop() { + llonebotError.wsServerError = '' + this.ws?.close(err => { + log('ws server close failed!', err) + }) + this.ws = undefined + } + + restart(port: number) { + this.stop() + this.start(port) + } + + private authorize(socket: WebSocket, req: IncomingMessage) { + const { token } = getConfigUtil().getConfig() + const url = req.url?.split('?').shift() + log('ws connect', url) + let clientToken = '' + const authHeader = req.headers['authorization'] + if (authHeader) { + clientToken = authHeader.split('Bearer ').pop()! + log('receive ws header token', clientToken) + } else { + const { searchParams } = new URL(`http://localhost${req.url}`) + const urlToken = searchParams.get('access_token') + if (urlToken) { + if (Array.isArray(urlToken)) { + clientToken = urlToken[0] + } else { + clientToken = urlToken + } + log('receive ws url token', clientToken) + } + } + if (token && clientToken !== token) { + this.authorizeFailed(socket) + return socket.close() + } + } + + private authorizeFailed(socket: WebSocket) { + socket.send(JSON.stringify(OB11Response.res(null, 'failed', 1403, 'token验证失败'))) + } + + private async handleAction(socket: WebSocket, actionName: string, params: any, echo?: any) { const action: BaseAction = actionMap.get(actionName)! if (!action) { - return wsReply(wsClient, OB11Response.error('不支持的api ' + actionName, 1404, echo)) + return wsReply(socket, OB11Response.error('不支持的api ' + actionName, 1404, echo)) } try { - let handleResult = await action.websocketHandle(params, echo) + const handleResult = await action.websocketHandle(params, echo) handleResult.echo = echo - wsReply(wsClient, handleResult) + wsReply(socket, handleResult) } catch (e: any) { - wsReply(wsClient, OB11Response.error(`api处理出错:${e.stack}`, 1200, echo)) + wsReply(socket, OB11Response.error(`api处理出错:${e.stack}`, 1200, echo)) } } - onConnect(wsClient: WebSocket, url: string, req: IncomingMessage) { - if (url == '/api' || url == '/api/' || url == '/') { - wsClient.on('message', async (msg) => { + private onConnect(socket: WebSocket, url: string) { + if (['/api', '/api/', '/'].includes(url)) { + socket.on('message', async (msg) => { let receiveData: { action: ActionName | null; params: any; echo?: any } = { action: null, params: {} } - let echo = null + let echo: any try { receiveData = JSON.parse(msg.toString()) echo = receiveData.echo log('收到正向Websocket消息', receiveData) } catch (e) { - return wsReply(wsClient, OB11Response.error('json解析失败,请检查数据格式', 1400, echo)) + return wsReply(socket, OB11Response.error('json解析失败,请检查数据格式', 1400, echo)) } - this.handleAction(wsClient, receiveData.action!, receiveData.params, receiveData.echo).then() + this.handleAction(socket, receiveData.action!, receiveData.params, receiveData.echo) }) } - if (url == '/event' || url == '/event/' || url == '/') { - registerWsEventSender(wsClient) + if (['/event', '/event/', '/'].includes(url)) { + registerWsEventSender(socket) log('event上报ws客户端已连接') try { - wsReply(wsClient, new OB11LifeCycleEvent(LifeCycleSubType.CONNECT)) + wsReply(socket, new OB11LifeCycleEvent(LifeCycleSubType.CONNECT)) } catch (e) { log('发送生命周期失败', e) } const { heartInterval } = getConfigUtil().getConfig() - const wsClientInterval = setInterval(() => { + const intervalId = setInterval(() => { postWsEvent(new OB11HeartbeatEvent(getSelfInfo().online!, true, heartInterval!)) }, heartInterval) // 心跳包 - wsClient.on('close', () => { + socket.on('close', () => { log('event上报ws客户端已断开') - clearInterval(wsClientInterval) - unregisterWsEventSender(wsClient) + clearInterval(intervalId) + unregisterWsEventSender(socket) }) } }