fix: coerce

This commit is contained in:
手瓜一十雪
2025-04-17 22:17:35 +08:00
parent 9df7c341a9
commit fb20b2e16c
7 changed files with 40 additions and 35 deletions

View File

@@ -1,14 +1,26 @@
import { z } from 'zod'; import { z } from 'zod';
const boolean = () => z.preprocess( const boolean = () => z.preprocess(
val => typeof val === 'string' && (val.toLowerCase() === 'false' || val === '0') ? false : Boolean(val), val => val === null || val === undefined
? val
: typeof val === 'string' && (val.toLowerCase() === 'false' || val === '0')
? false
: Boolean(val),
z.boolean() z.boolean()
); );
const number = () => z.preprocess( const number = () => z.preprocess(
val => typeof val !== 'number' ? Number(val) : val, val => val === null || val === undefined
? val
: typeof val !== 'number' ? Number(val) : val,
z.number() z.number()
); );
const string = () => z.preprocess( const string = () => z.preprocess(
val => typeof val !== 'string' ? String(val) : val, val => val === null || val === undefined
? val
: typeof val !== 'string' ? String(val) : val,
z.string() z.string()
); );
export const coerce = { boolean, number, string }; export const coerce = { boolean, number, string };

View File

@@ -32,7 +32,7 @@ export class OB11Response {
export abstract class OneBotAction<PayloadType, ReturnDataType> { export abstract class OneBotAction<PayloadType, ReturnDataType> {
actionName: typeof ActionName[keyof typeof ActionName] = ActionName.Unknown; actionName: typeof ActionName[keyof typeof ActionName] = ActionName.Unknown;
core: NapCatCore; core: NapCatCore;
payloadSchema?: z.ZodType<unknown> = undefined; payloadSchema?: z.ZodType<PayloadType, z.ZodTypeDef, unknown> = undefined;
obContext: NapCatOneBot11Adapter; obContext: NapCatOneBot11Adapter;
constructor(obContext: NapCatOneBot11Adapter, core: NapCatCore) { constructor(obContext: NapCatOneBot11Adapter, core: NapCatCore) {
@@ -40,15 +40,15 @@ export abstract class OneBotAction<PayloadType, ReturnDataType> {
this.core = core; this.core = core;
} }
protected async check(payload: unknown): Promise<BaseCheckResult> { protected async check(payload: unknown): Promise<BaseCheckResult & { parsedPayload?: PayloadType }> {
if (!this.payloadSchema) { if (!this.payloadSchema) {
return { valid: true }; return { valid: true, parsedPayload: payload as PayloadType };
} }
try { try {
// 使用 zod 验证并转换数据 // 使用 zod 验证并转换数据,并返回解析后的数据
this.payloadSchema.parse(payload); const parsedPayload = this.payloadSchema.parse(payload) as PayloadType;
return { valid: true }; return { valid: true, parsedPayload };
} catch (error) { } catch (error) {
if (error instanceof z.ZodError) { if (error instanceof z.ZodError) {
const errorMessages = error.errors.map(e => const errorMessages = error.errors.map(e =>
@@ -66,13 +66,13 @@ export abstract class OneBotAction<PayloadType, ReturnDataType> {
} }
} }
public async handle(payload: PayloadType, adaptername: string, config: NetworkAdapterConfig): Promise<OB11Return<ReturnDataType | null>> { public async handle(payload: unknown, adaptername: string, config: NetworkAdapterConfig): Promise<OB11Return<ReturnDataType | null>> {
const result = await this.check(payload); const result = await this.check(payload);
if (!result.valid) { if (!result.valid) {
return OB11Response.error(result.message, 400); return OB11Response.error(result.message!, 400);
} }
try { try {
const resData = await this._handle(payload, adaptername, config); const resData = await this._handle(result.parsedPayload as PayloadType, adaptername, config);
return OB11Response.ok(resData); return OB11Response.ok(resData);
} catch (e: unknown) { } catch (e: unknown) {
this.core.context.logger.logError('发生错误', e); this.core.context.logger.logError('发生错误', e);
@@ -80,13 +80,13 @@ export abstract class OneBotAction<PayloadType, ReturnDataType> {
} }
} }
public async websocketHandle(payload: PayloadType, echo: unknown, adaptername: string, config: NetworkAdapterConfig): Promise<OB11Return<ReturnDataType | null>> { public async websocketHandle(payload: unknown, echo: unknown, adaptername: string, config: NetworkAdapterConfig): Promise<OB11Return<ReturnDataType | null>> {
const result = await this.check(payload); const result = await this.check(payload);
if (!result.valid) { if (!result.valid) {
return OB11Response.error(result.message, 1400, echo); return OB11Response.error(result.message!, 1400, echo);
} }
try { try {
const resData = await this._handle(payload, adaptername, config); const resData = await this._handle(result.parsedPayload as PayloadType, adaptername, config);
return OB11Response.ok(resData, echo); return OB11Response.ok(resData, echo);
} catch (e: unknown) { } catch (e: unknown) {
this.core.context.logger.logError('发生错误', e); this.core.context.logger.logError('发生错误', e);

View File

@@ -4,18 +4,13 @@ import { ActionName } from '@/onebot/action/router';
// 未验证 // 未验证
export class GoCQHTTPSendForwardMsgBase extends SendMsgBase { export class GoCQHTTPSendForwardMsgBase extends SendMsgBase {
protected override async check(payload: OB11PostSendMsg) { override async _handle(payload: OB11PostSendMsg) {
if (payload.messages) payload.message = normalize(payload.messages); if (payload.messages) payload.message = normalize(payload.messages);
return super.check(payload); return super._handle(payload);
} }
} }
export class GoCQHTTPSendForwardMsg extends GoCQHTTPSendForwardMsgBase { export class GoCQHTTPSendForwardMsg extends GoCQHTTPSendForwardMsgBase {
override actionName = ActionName.GoCQHTTP_SendForwardMsg; override actionName = ActionName.GoCQHTTP_SendForwardMsg;
protected override async check(payload: OB11PostSendMsg) {
if (payload.messages) payload.message = normalize(payload.messages);
return super.check(payload);
}
} }
export class GoCQHTTPSendPrivateForwardMsg extends GoCQHTTPSendForwardMsgBase { export class GoCQHTTPSendPrivateForwardMsg extends GoCQHTTPSendForwardMsgBase {
override actionName = ActionName.GoCQHTTP_SendPrivateForwardMsg; override actionName = ActionName.GoCQHTTP_SendPrivateForwardMsg;

View File

@@ -1,16 +1,15 @@
import { ContextMode, SendMsgBase } from '@/onebot/action/msg/SendMsg'; import { ContextMode, SendMsgBase } from '@/onebot/action/msg/SendMsg';
import { ActionName, BaseCheckResult } from '@/onebot/action/router'; import { ActionName } from '@/onebot/action/router';
import { OB11PostSendMsg } from '@/onebot/types'; import { OB11PostSendMsg } from '@/onebot/types';
// 未检测参数 // 未检测参数
class SendGroupMsg extends SendMsgBase { class SendGroupMsg extends SendMsgBase {
override actionName = ActionName.SendGroupMsg; override actionName = ActionName.SendGroupMsg;
override contextMode: ContextMode = ContextMode.Group; override contextMode: ContextMode = ContextMode.Group;
override async _handle(payload: OB11PostSendMsg) {
protected override async check(payload: OB11PostSendMsg): Promise<BaseCheckResult> {
delete payload.user_id; delete payload.user_id;
payload.message_type = 'group'; payload.message_type = 'group';
return super.check(payload); return super._handle(payload);
} }
} }

View File

@@ -91,7 +91,7 @@ function getSpecialMsgNum(payload: OB11PostSendMsg, msgType: OB11MessageDataType
export class SendMsgBase extends OneBotAction<OB11PostSendMsg, ReturnDataType> { export class SendMsgBase extends OneBotAction<OB11PostSendMsg, ReturnDataType> {
contextMode = ContextMode.Normal; contextMode = ContextMode.Normal;
protected override async check(payload: OB11PostSendMsg): Promise<BaseCheckResult> { protected override async check(payload: OB11PostSendMsg): Promise<BaseCheckResult & { parsedPayload?: OB11PostSendMsg }> {
const messages = normalize(payload.message); const messages = normalize(payload.message);
const nodeElementLength = getSpecialMsgNum(payload, OB11MessageDataType.node); const nodeElementLength = getSpecialMsgNum(payload, OB11MessageDataType.node);
if (nodeElementLength > 0 && nodeElementLength != messages.length) { if (nodeElementLength > 0 && nodeElementLength != messages.length) {
@@ -100,7 +100,7 @@ export class SendMsgBase extends OneBotAction<OB11PostSendMsg, ReturnDataType> {
message: '转发消息不能和普通消息混在一起发送,转发需要保证message只有type为node的元素', message: '转发消息不能和普通消息混在一起发送,转发需要保证message只有type为node的元素',
}; };
} }
return { valid: true }; return { valid: true , parsedPayload: payload };
} }
async _handle(payload: OB11PostSendMsg): Promise<ReturnDataType> { async _handle(payload: OB11PostSendMsg): Promise<ReturnDataType> {

View File

@@ -1,15 +1,14 @@
import { ContextMode, SendMsgBase } from './SendMsg'; import { ContextMode, SendMsgBase } from './SendMsg';
import { ActionName, BaseCheckResult } from '@/onebot/action/router'; import { ActionName } from '@/onebot/action/router';
import { OB11PostSendMsg } from '@/onebot/types'; import { OB11PostSendMsg } from '@/onebot/types';
// 未检测参数 // 未检测参数
class SendPrivateMsg extends SendMsgBase { class SendPrivateMsg extends SendMsgBase {
override actionName = ActionName.SendPrivateMsg; override actionName = ActionName.SendPrivateMsg;
override contextMode: ContextMode = ContextMode.Private; override contextMode: ContextMode = ContextMode.Private;
override async _handle(payload: OB11PostSendMsg) {
protected override async check(payload: OB11PostSendMsg): Promise<BaseCheckResult> { if (payload.messages) payload.message = payload.messages;
payload.message_type = 'private'; return super._handle(payload);
return super.check(payload);
} }
} }

View File

@@ -1,9 +1,9 @@
import { OneBotAction } from '@/onebot/action/OneBotAction'; import { OneBotAction } from '@/onebot/action/OneBotAction';
import { ActionName, BaseCheckResult } from '@/onebot/action/router'; import { ActionName } from '@/onebot/action/router';
export abstract class GetPacketStatusDepends<PT, RT> extends OneBotAction<PT, RT> { export abstract class GetPacketStatusDepends<PT, RT> extends OneBotAction<PT, RT> {
protected override async check(payload: PT): Promise<BaseCheckResult>{ protected override async check(payload: PT) {
if (!this.core.apis.PacketApi.available) { if (!this.core.apis.PacketApi.available) {
return { return {
valid: false, valid: false,