improve condition evaluate

This commit is contained in:
Yoan.liu 2025-05-20 22:54:41 +08:00
parent 97d692910b
commit faad7cb6d7
8 changed files with 440 additions and 45 deletions

View File

@ -26,18 +26,276 @@ const (
Not LogicalOperator = "not" Not LogicalOperator = "not"
) )
type EvalResult struct {
Type string
Value any
}
func (e *EvalResult) GetFloat64() (float64, error) {
if e.Type != "number" {
return 0, fmt.Errorf("type mismatch: %s", e.Type)
}
switch v := e.Value.(type) {
case int:
return float64(v), nil
case float64:
return v, nil
default:
return 0, fmt.Errorf("unsupported type: %T", v)
}
}
func (e *EvalResult) GreaterThan(other *EvalResult) (*EvalResult, error) {
if e.Type != other.Type {
return nil, fmt.Errorf("type mismatch: %s vs %s", e.Type, other.Type)
}
switch e.Type {
case "number":
left, err := e.GetFloat64()
if err != nil {
return nil, err
}
right, err := other.GetFloat64()
if err != nil {
return nil, err
}
return &EvalResult{
Type: "boolean",
Value: left > right,
}, nil
case "string":
return &EvalResult{
Type: "boolean",
Value: e.Value.(string) > other.Value.(string),
}, nil
default:
return nil, fmt.Errorf("unsupported type: %s", e.Type)
}
}
func (e *EvalResult) GreaterOrEqual(other *EvalResult) (*EvalResult, error) {
if e.Type != other.Type {
return nil, fmt.Errorf("type mismatch: %s vs %s", e.Type, other.Type)
}
switch e.Type {
case "number":
left, err := e.GetFloat64()
if err != nil {
return nil, err
}
right, err := other.GetFloat64()
if err != nil {
return nil, err
}
return &EvalResult{
Type: "boolean",
Value: left >= right,
}, nil
case "string":
return &EvalResult{
Type: "boolean",
Value: e.Value.(string) >= other.Value.(string),
}, nil
default:
return nil, fmt.Errorf("unsupported type: %s", e.Type)
}
}
func (e *EvalResult) LessThan(other *EvalResult) (*EvalResult, error) {
if e.Type != other.Type {
return nil, fmt.Errorf("type mismatch: %s vs %s", e.Type, other.Type)
}
switch e.Type {
case "number":
left, err := e.GetFloat64()
if err != nil {
return nil, err
}
right, err := other.GetFloat64()
if err != nil {
return nil, err
}
return &EvalResult{
Type: "boolean",
Value: left < right,
}, nil
case "string":
return &EvalResult{
Type: "boolean",
Value: e.Value.(string) < other.Value.(string),
}, nil
default:
return nil, fmt.Errorf("unsupported type: %s", e.Type)
}
}
func (e *EvalResult) LessOrEqual(other *EvalResult) (*EvalResult, error) {
if e.Type != other.Type {
return nil, fmt.Errorf("type mismatch: %s vs %s", e.Type, other.Type)
}
switch e.Type {
case "number":
left, err := e.GetFloat64()
if err != nil {
return nil, err
}
right, err := other.GetFloat64()
if err != nil {
return nil, err
}
return &EvalResult{
Type: "boolean",
Value: left <= right,
}, nil
case "string":
return &EvalResult{
Type: "boolean",
Value: e.Value.(string) <= other.Value.(string),
}, nil
default:
return nil, fmt.Errorf("unsupported type: %s", e.Type)
}
}
func (e *EvalResult) Equal(other *EvalResult) (*EvalResult, error) {
if e.Type != other.Type {
return nil, fmt.Errorf("type mismatch: %s vs %s", e.Type, other.Type)
}
switch e.Type {
case "number":
left, err := e.GetFloat64()
if err != nil {
return nil, err
}
right, err := other.GetFloat64()
if err != nil {
return nil, err
}
return &EvalResult{
Type: "boolean",
Value: left == right,
}, nil
case "string":
return &EvalResult{
Type: "boolean",
Value: e.Value.(string) == other.Value.(string),
}, nil
default:
return nil, fmt.Errorf("unsupported type: %s", e.Type)
}
}
func (e *EvalResult) NotEqual(other *EvalResult) (*EvalResult, error) {
if e.Type != other.Type {
return nil, fmt.Errorf("type mismatch: %s vs %s", e.Type, other.Type)
}
switch e.Type {
case "number":
left, err := e.GetFloat64()
if err != nil {
return nil, err
}
right, err := other.GetFloat64()
if err != nil {
return nil, err
}
return &EvalResult{
Type: "boolean",
Value: left != right,
}, nil
case "string":
return &EvalResult{
Type: "boolean",
Value: e.Value.(string) != other.Value.(string),
}, nil
default:
return nil, fmt.Errorf("unsupported type: %s", e.Type)
}
}
func (e *EvalResult) And(other *EvalResult) (*EvalResult, error) {
if e.Type != other.Type {
return nil, fmt.Errorf("type mismatch: %s vs %s", e.Type, other.Type)
}
switch e.Type {
case "boolean":
return &EvalResult{
Type: "boolean",
Value: e.Value.(bool) && other.Value.(bool),
}, nil
default:
return nil, fmt.Errorf("unsupported type: %s", e.Type)
}
}
func (e *EvalResult) Or(other *EvalResult) (*EvalResult, error) {
if e.Type != other.Type {
return nil, fmt.Errorf("type mismatch: %s vs %s", e.Type, other.Type)
}
switch e.Type {
case "boolean":
return &EvalResult{
Type: "boolean",
Value: e.Value.(bool) || other.Value.(bool),
}, nil
default:
return nil, fmt.Errorf("unsupported type: %s", e.Type)
}
}
func (e *EvalResult) Not() (*EvalResult, error) {
if e.Type != "boolean" {
return nil, fmt.Errorf("type mismatch: %s", e.Type)
}
return &EvalResult{
Type: "boolean",
Value: !e.Value.(bool),
}, nil
}
func (e *EvalResult) Is(other *EvalResult) (*EvalResult, error) {
if e.Type != other.Type {
return nil, fmt.Errorf("type mismatch: %s vs %s", e.Type, other.Type)
}
switch e.Type {
case "boolean":
return &EvalResult{
Type: "boolean",
Value: e.Value.(bool) == other.Value.(bool),
}, nil
default:
return nil, fmt.Errorf("unsupported type: %s", e.Type)
}
}
type Expr interface { type Expr interface {
GetType() string GetType() string
Eval(variables map[string]map[string]any) (any, error) Eval(variables map[string]map[string]any) (*EvalResult, error)
} }
type ConstExpr struct { type ConstExpr struct {
Type string `json:"type"` Type string `json:"type"`
Value Value `json:"value"` Value Value `json:"value"`
ValueType string `json:"valueType"`
} }
func (c ConstExpr) GetType() string { return c.Type } func (c ConstExpr) GetType() string { return c.Type }
func (c ConstExpr) Eval(variables map[string]map[string]any) (*EvalResult, error) {
return &EvalResult{
Type: c.ValueType,
Value: c.Value,
}, nil
}
type VarExpr struct { type VarExpr struct {
Type string `json:"type"` Type string `json:"type"`
Selector WorkflowNodeIOValueSelector `json:"selector"` Selector WorkflowNodeIOValueSelector `json:"selector"`
@ -45,7 +303,7 @@ type VarExpr struct {
func (v VarExpr) GetType() string { return v.Type } func (v VarExpr) GetType() string { return v.Type }
func (v VarExpr) Eval(variables map[string]map[string]any) (any, error) { func (v VarExpr) Eval(variables map[string]map[string]any) (*EvalResult, error) {
if v.Selector.Id == "" { if v.Selector.Id == "" {
return nil, fmt.Errorf("node id is empty") return nil, fmt.Errorf("node id is empty")
} }
@ -58,10 +316,12 @@ func (v VarExpr) Eval(variables map[string]map[string]any) (any, error) {
} }
if _, ok := variables[v.Selector.Id][v.Selector.Name]; !ok { if _, ok := variables[v.Selector.Id][v.Selector.Name]; !ok {
return nil, fmt.Errorf("variable %s not found in node %s", v.Selector.Name, v.Selector.NodeId) return nil, fmt.Errorf("variable %s not found in node %s", v.Selector.Name, v.Selector.Id)
} }
return &EvalResult{
return variables[v.Selector.Id][v.Selector.Name], nil Type: v.Selector.Type,
Value: variables[v.Selector.Id][v.Selector.Name],
}, nil
} }
type CompareExpr struct { type CompareExpr struct {
@ -73,7 +333,7 @@ type CompareExpr struct {
func (c CompareExpr) GetType() string { return c.Type } func (c CompareExpr) GetType() string { return c.Type }
func (c CompareExpr) Eval(variables map[string]map[string]any) (any, error) { func (c CompareExpr) Eval(variables map[string]map[string]any) (*EvalResult, error) {
left, err := c.Left.Eval(variables) left, err := c.Left.Eval(variables)
if err != nil { if err != nil {
return nil, err return nil, err
@ -85,19 +345,19 @@ func (c CompareExpr) Eval(variables map[string]map[string]any) (any, error) {
switch c.Op { switch c.Op {
case GreaterThan: case GreaterThan:
return left.(float64) > right.(float64), nil return left.GreaterThan(right)
case LessThan: case LessThan:
return left.(float64) < right.(float64), nil return left.LessThan(right)
case GreaterOrEqual: case GreaterOrEqual:
return left.(float64) >= right.(float64), nil return left.GreaterOrEqual(right)
case LessOrEqual: case LessOrEqual:
return left.(float64) <= right.(float64), nil return left.LessOrEqual(right)
case Equal: case Equal:
return left == right, nil return left.Equal(right)
case NotEqual: case NotEqual:
return left != right, nil return left.NotEqual(right)
case Is: case Is:
return left == right, nil return left.Is(right)
default: default:
return nil, fmt.Errorf("unknown operator: %s", c.Op) return nil, fmt.Errorf("unknown operator: %s", c.Op)
} }
@ -112,7 +372,7 @@ type LogicalExpr struct {
func (l LogicalExpr) GetType() string { return l.Type } func (l LogicalExpr) GetType() string { return l.Type }
func (l LogicalExpr) Eval(variables map[string]map[string]any) (any, error) { func (l LogicalExpr) Eval(variables map[string]map[string]any) (*EvalResult, error) {
left, err := l.Left.Eval(variables) left, err := l.Left.Eval(variables)
if err != nil { if err != nil {
return nil, err return nil, err
@ -124,9 +384,9 @@ func (l LogicalExpr) Eval(variables map[string]map[string]any) (any, error) {
switch l.Op { switch l.Op {
case And: case And:
return left.(bool) && right.(bool), nil return left.And(right)
case Or: case Or:
return left.(bool) || right.(bool), nil return left.Or(right)
default: default:
return nil, fmt.Errorf("unknown operator: %s", l.Op) return nil, fmt.Errorf("unknown operator: %s", l.Op)
} }
@ -139,12 +399,12 @@ type NotExpr struct {
func (n NotExpr) GetType() string { return n.Type } func (n NotExpr) GetType() string { return n.Type }
func (n NotExpr) Eval(variables map[string]map[string]any) (any, error) { func (n NotExpr) Eval(variables map[string]map[string]any) (*EvalResult, error) {
inner, err := n.Expr.Eval(variables) inner, err := n.Expr.Eval(variables)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return !inner.(bool), nil return inner.Not()
} }
type rawExpr struct { type rawExpr struct {

View File

@ -0,0 +1,127 @@
package domain
import (
"testing"
)
func TestLogicalEval(t *testing.T) {
// 测试逻辑表达式 and
logicalExpr := LogicalExpr{
Left: ConstExpr{
Type: "const",
Value: true,
ValueType: "boolean",
},
Op: And,
Right: ConstExpr{
Type: "const",
Value: true,
ValueType: "boolean",
},
}
result, err := logicalExpr.Eval(nil)
if err != nil {
t.Errorf("failed to evaluate logical expression: %v", err)
}
if result.Value != true {
t.Errorf("expected true, got %v", result)
}
// 测试逻辑表达式 or
orExpr := LogicalExpr{
Left: ConstExpr{
Type: "const",
Value: true,
ValueType: "boolean",
},
Op: Or,
Right: ConstExpr{
Type: "const",
Value: true,
ValueType: "boolean",
},
}
result, err = orExpr.Eval(nil)
if err != nil {
t.Errorf("failed to evaluate logical expression: %v", err)
}
if result.Value != true {
t.Errorf("expected true, got %v", result)
}
}
func TestUnmarshalExpr(t *testing.T) {
type args struct {
data []byte
}
tests := []struct {
name string
args args
want Expr
wantErr bool
}{
{
name: "test1",
args: args{
data: []byte(`{"left":{"left":{"selector":{"id":"ODnYSOXB6HQP2_vz6JcZE","name":"certificate.validated","type":"boolean"},"type":"var"},"op":"is","right":{"type":"const","value":true,"valueType":"boolean"},"type":"compare"},"op":"and","right":{"left":{"selector":{"id":"ODnYSOXB6HQP2_vz6JcZE","name":"certificate.daysLeft","type":"number"},"type":"var"},"op":"==","right":{"type":"const","value":2,"valueType":"number"},"type":"compare"},"type":"logical"}`),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := UnmarshalExpr(tt.args.data)
if (err != nil) != tt.wantErr {
t.Errorf("UnmarshalExpr() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got == nil {
t.Errorf("UnmarshalExpr() got = nil, want %v", tt.want)
return
}
})
}
}
func TestExpr_Eval(t *testing.T) {
type args struct {
variables map[string]map[string]any
data []byte
}
tests := []struct {
name string
args args
want *EvalResult
wantErr bool
}{
{
name: "test1",
args: args{
variables: map[string]map[string]any{
"ODnYSOXB6HQP2_vz6JcZE": {
"certificate.validated": true,
"certificate.daysLeft": 2,
},
},
data: []byte(`{"left":{"left":{"selector":{"id":"ODnYSOXB6HQP2_vz6JcZE","name":"certificate.validated","type":"boolean"},"type":"var"},"op":"is","right":{"type":"const","value":true,"valueType":"boolean"},"type":"compare"},"op":"and","right":{"left":{"selector":{"id":"ODnYSOXB6HQP2_vz6JcZE","name":"certificate.daysLeft","type":"number"},"type":"var"},"op":"==","right":{"type":"const","value":2,"valueType":"number"},"type":"compare"},"type":"logical"}`),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, err := UnmarshalExpr(tt.args.data)
if err != nil {
t.Errorf("UnmarshalExpr() error = %v", err)
return
}
got, err := c.Eval(tt.args.variables)
t.Log("got:", got)
if (err != nil) != tt.wantErr {
t.Errorf("ConstExpr.Eval() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got.Value != true {
t.Errorf("ConstExpr.Eval() got = %v, want %v", got.Value, true)
}
})
}
}

View File

@ -1,6 +1,7 @@
package domain package domain
import ( import (
"encoding/json"
"time" "time"
maputil "github.com/usual2970/certimate/internal/pkg/utils/map" maputil "github.com/usual2970/certimate/internal/pkg/utils/map"
@ -109,11 +110,13 @@ type WorkflowNodeConfigForNotify struct {
} }
func (n *WorkflowNode) GetConfigForCondition() WorkflowNodeConfigForCondition { func (n *WorkflowNode) GetConfigForCondition() WorkflowNodeConfigForCondition {
raw := maputil.GetString(n.Config, "expression") expression := n.Config["expression"]
if raw == "" { if expression == nil {
return WorkflowNodeConfigForCondition{} return WorkflowNodeConfigForCondition{}
} }
raw, _ := json.Marshal(expression)
expr, err := UnmarshalExpr([]byte(raw)) expr, err := UnmarshalExpr([]byte(raw))
if err != nil { if err != nil {
return WorkflowNodeConfigForCondition{} return WorkflowNodeConfigForCondition{}

View File

@ -98,7 +98,9 @@ func (w *workflowInvoker) processNode(ctx context.Context, node *domain.Workflow
procErr = processor.Process(ctx) procErr = processor.Process(ctx)
if procErr != nil { if procErr != nil {
if current.Type != domain.WorkflowNodeTypeCondition {
processor.GetLogger().Error(procErr.Error()) processor.GetLogger().Error(procErr.Error())
}
break break
} }
@ -110,9 +112,12 @@ func (w *workflowInvoker) processNode(ctx context.Context, node *domain.Workflow
break break
} }
// TODO: 优化可读性 // TODO: 优化可读性
if procErr != nil && current.Next != nil && current.Next.Type != domain.WorkflowNodeTypeExecuteResultBranch { if procErr != nil && current.Type == domain.WorkflowNodeTypeCondition {
current = nil
procErr = nil
return nil
} else if procErr != nil && current.Next != nil && current.Next.Type != domain.WorkflowNodeTypeExecuteResultBranch {
return procErr return procErr
} else if procErr != nil && current.Next != nil && current.Next.Type == domain.WorkflowNodeTypeExecuteResultBranch { } else if procErr != nil && current.Next != nil && current.Next.Type == domain.WorkflowNodeTypeExecuteResultBranch {
current = w.getBranchByType(current.Next.Branches, domain.WorkflowNodeTypeExecuteFailure) current = w.getBranchByType(current.Next.Branches, domain.WorkflowNodeTypeExecuteFailure)

View File

@ -26,27 +26,26 @@ func (n *conditionNode) Process(ctx context.Context) error {
nodeConfig := n.node.GetConfigForCondition() nodeConfig := n.node.GetConfigForCondition()
if nodeConfig.Expression == nil { if nodeConfig.Expression == nil {
n.logger.Info("no condition found, continue to next node")
return nil return nil
} }
rs, err := n.eval(ctx, nodeConfig.Expression)
if err != nil {
n.logger.Warn("failed to eval expression: " + err.Error())
return err
}
if rs.Value == false {
n.logger.Info("condition not met, skip this branch")
return errors.New("condition not met")
}
n.logger.Info("condition met, continue to next node")
return nil return nil
} }
func (n *conditionNode) eval(ctx context.Context, expression domain.Expr) (any, error) { func (n *conditionNode) eval(ctx context.Context, expression domain.Expr) (*domain.EvalResult, error) {
switch expr:=expression.(type) { variables := GetNodeOutputs(ctx)
case domain.CompareExpr: return expression.Eval(variables)
left,err:= n.eval(ctx, expr.Left)
if err != nil {
return nil, err
}
right,err:= n.eval(ctx, expr.Right)
if err != nil {
return nil, err
}
case domain.LogicalExpr:
case domain.NotExpr:
case domain.VarExpr:
case domain.ConstExpr:
}
return false, errors.New("unknown expression type")
} }

View File

@ -57,7 +57,7 @@ const ConditionNode = ({ node, disabled, branchId, branchIndex }: ConditionNodeP
break; break;
} }
const right: Expr = { type: "const", value: value }; const right: Expr = { type: "const", value: value, valueType: t };
return { return {
type: "compare", type: "compare",

View File

@ -318,6 +318,7 @@ const formToExpression = (values: ConditionNodeConfigFormFieldValues): Expr => {
const right: Expr = { const right: Expr = {
type: "const", type: "const",
value: rightValue, value: rightValue,
valueType: type,
}; };
return { return {

View File

@ -238,7 +238,7 @@ export type ComparisonOperator = ">" | "<" | ">=" | "<=" | "==" | "!=" | "is";
export type LogicalOperator = "and" | "or" | "not"; export type LogicalOperator = "and" | "or" | "not";
export type ConstExpr = { type: "const"; value: Value }; export type ConstExpr = { type: "const"; value: Value; valueType: WorkflowNodeIoValueType };
export type VarExpr = { type: "var"; selector: WorkflowNodeIOValueSelector }; export type VarExpr = { type: "var"; selector: WorkflowNodeIOValueSelector };
export type CompareExpr = { type: "compare"; op: ComparisonOperator; left: Expr; right: Expr }; export type CompareExpr = { type: "compare"; op: ComparisonOperator; left: Expr; right: Expr };
export type LogicalExpr = { type: "logical"; op: LogicalOperator; left: Expr; right: Expr }; export type LogicalExpr = { type: "logical"; op: LogicalOperator; left: Expr; right: Expr };