From faad7cb6d791ff59a2def22bdefb9c2422d62f49 Mon Sep 17 00:00:00 2001 From: "Yoan.liu" Date: Tue, 20 May 2025 22:54:41 +0800 Subject: [PATCH] improve condition evaluate --- internal/domain/expr.go | 300 ++++++++++++++++-- internal/domain/expr_test.go | 127 ++++++++ internal/domain/workflow.go | 7 +- internal/workflow/dispatcher/invoker.go | 11 +- .../workflow/node-processor/condition_node.go | 35 +- .../workflow/node/ConditionNode.tsx | 2 +- .../workflow/node/ConditionNodeConfigForm.tsx | 1 + ui/src/domain/workflow.ts | 2 +- 8 files changed, 440 insertions(+), 45 deletions(-) create mode 100644 internal/domain/expr_test.go diff --git a/internal/domain/expr.go b/internal/domain/expr.go index 3b312642..4791ba7d 100644 --- a/internal/domain/expr.go +++ b/internal/domain/expr.go @@ -26,18 +26,276 @@ const ( Not LogicalOperator = "not" ) +type EvalResult struct { + Type string + Value any +} + +func (e *EvalResult) GetFloat64() (float64, error) { + if e.Type != "number" { + return 0, fmt.Errorf("type mismatch: %s", e.Type) + } + switch v := e.Value.(type) { + case int: + return float64(v), nil + case float64: + return v, nil + default: + return 0, fmt.Errorf("unsupported type: %T", v) + } +} + +func (e *EvalResult) GreaterThan(other *EvalResult) (*EvalResult, error) { + if e.Type != other.Type { + return nil, fmt.Errorf("type mismatch: %s vs %s", e.Type, other.Type) + } + switch e.Type { + case "number": + + left, err := e.GetFloat64() + if err != nil { + return nil, err + } + right, err := other.GetFloat64() + if err != nil { + return nil, err + } + + return &EvalResult{ + Type: "boolean", + Value: left > right, + }, nil + case "string": + return &EvalResult{ + Type: "boolean", + Value: e.Value.(string) > other.Value.(string), + }, nil + + default: + return nil, fmt.Errorf("unsupported type: %s", e.Type) + } +} + +func (e *EvalResult) GreaterOrEqual(other *EvalResult) (*EvalResult, error) { + if e.Type != other.Type { + return nil, fmt.Errorf("type mismatch: %s vs %s", e.Type, other.Type) + } + switch e.Type { + case "number": + left, err := e.GetFloat64() + if err != nil { + return nil, err + } + right, err := other.GetFloat64() + if err != nil { + return nil, err + } + return &EvalResult{ + Type: "boolean", + Value: left >= right, + }, nil + case "string": + return &EvalResult{ + Type: "boolean", + Value: e.Value.(string) >= other.Value.(string), + }, nil + + default: + return nil, fmt.Errorf("unsupported type: %s", e.Type) + } +} + +func (e *EvalResult) LessThan(other *EvalResult) (*EvalResult, error) { + if e.Type != other.Type { + return nil, fmt.Errorf("type mismatch: %s vs %s", e.Type, other.Type) + } + switch e.Type { + case "number": + left, err := e.GetFloat64() + if err != nil { + return nil, err + } + right, err := other.GetFloat64() + if err != nil { + return nil, err + } + return &EvalResult{ + Type: "boolean", + Value: left < right, + }, nil + case "string": + return &EvalResult{ + Type: "boolean", + Value: e.Value.(string) < other.Value.(string), + }, nil + + default: + return nil, fmt.Errorf("unsupported type: %s", e.Type) + } +} + +func (e *EvalResult) LessOrEqual(other *EvalResult) (*EvalResult, error) { + if e.Type != other.Type { + return nil, fmt.Errorf("type mismatch: %s vs %s", e.Type, other.Type) + } + switch e.Type { + case "number": + left, err := e.GetFloat64() + if err != nil { + return nil, err + } + right, err := other.GetFloat64() + if err != nil { + return nil, err + } + return &EvalResult{ + Type: "boolean", + Value: left <= right, + }, nil + case "string": + return &EvalResult{ + Type: "boolean", + Value: e.Value.(string) <= other.Value.(string), + }, nil + + default: + return nil, fmt.Errorf("unsupported type: %s", e.Type) + } +} + +func (e *EvalResult) Equal(other *EvalResult) (*EvalResult, error) { + if e.Type != other.Type { + return nil, fmt.Errorf("type mismatch: %s vs %s", e.Type, other.Type) + } + switch e.Type { + case "number": + left, err := e.GetFloat64() + if err != nil { + return nil, err + } + right, err := other.GetFloat64() + if err != nil { + return nil, err + } + return &EvalResult{ + Type: "boolean", + Value: left == right, + }, nil + case "string": + return &EvalResult{ + Type: "boolean", + Value: e.Value.(string) == other.Value.(string), + }, nil + + default: + return nil, fmt.Errorf("unsupported type: %s", e.Type) + } +} + +func (e *EvalResult) NotEqual(other *EvalResult) (*EvalResult, error) { + if e.Type != other.Type { + return nil, fmt.Errorf("type mismatch: %s vs %s", e.Type, other.Type) + } + switch e.Type { + case "number": + left, err := e.GetFloat64() + if err != nil { + return nil, err + } + right, err := other.GetFloat64() + if err != nil { + return nil, err + } + return &EvalResult{ + Type: "boolean", + Value: left != right, + }, nil + case "string": + return &EvalResult{ + Type: "boolean", + Value: e.Value.(string) != other.Value.(string), + }, nil + + default: + return nil, fmt.Errorf("unsupported type: %s", e.Type) + } +} + +func (e *EvalResult) And(other *EvalResult) (*EvalResult, error) { + if e.Type != other.Type { + return nil, fmt.Errorf("type mismatch: %s vs %s", e.Type, other.Type) + } + switch e.Type { + case "boolean": + return &EvalResult{ + Type: "boolean", + Value: e.Value.(bool) && other.Value.(bool), + }, nil + default: + return nil, fmt.Errorf("unsupported type: %s", e.Type) + } +} + +func (e *EvalResult) Or(other *EvalResult) (*EvalResult, error) { + if e.Type != other.Type { + return nil, fmt.Errorf("type mismatch: %s vs %s", e.Type, other.Type) + } + switch e.Type { + case "boolean": + return &EvalResult{ + Type: "boolean", + Value: e.Value.(bool) || other.Value.(bool), + }, nil + default: + return nil, fmt.Errorf("unsupported type: %s", e.Type) + } +} + +func (e *EvalResult) Not() (*EvalResult, error) { + if e.Type != "boolean" { + return nil, fmt.Errorf("type mismatch: %s", e.Type) + } + return &EvalResult{ + Type: "boolean", + Value: !e.Value.(bool), + }, nil +} + +func (e *EvalResult) Is(other *EvalResult) (*EvalResult, error) { + if e.Type != other.Type { + return nil, fmt.Errorf("type mismatch: %s vs %s", e.Type, other.Type) + } + switch e.Type { + case "boolean": + return &EvalResult{ + Type: "boolean", + Value: e.Value.(bool) == other.Value.(bool), + }, nil + default: + return nil, fmt.Errorf("unsupported type: %s", e.Type) + } +} + type Expr interface { GetType() string - Eval(variables map[string]map[string]any) (any, error) + Eval(variables map[string]map[string]any) (*EvalResult, error) } type ConstExpr struct { - Type string `json:"type"` - Value Value `json:"value"` + Type string `json:"type"` + Value Value `json:"value"` + ValueType string `json:"valueType"` } func (c ConstExpr) GetType() string { return c.Type } +func (c ConstExpr) Eval(variables map[string]map[string]any) (*EvalResult, error) { + return &EvalResult{ + Type: c.ValueType, + Value: c.Value, + }, nil +} + type VarExpr struct { Type string `json:"type"` Selector WorkflowNodeIOValueSelector `json:"selector"` @@ -45,7 +303,7 @@ type VarExpr struct { func (v VarExpr) GetType() string { return v.Type } -func (v VarExpr) Eval(variables map[string]map[string]any) (any, error) { +func (v VarExpr) Eval(variables map[string]map[string]any) (*EvalResult, error) { if v.Selector.Id == "" { return nil, fmt.Errorf("node id is empty") } @@ -58,10 +316,12 @@ func (v VarExpr) Eval(variables map[string]map[string]any) (any, error) { } if _, ok := variables[v.Selector.Id][v.Selector.Name]; !ok { - return nil, fmt.Errorf("variable %s not found in node %s", v.Selector.Name, v.Selector.NodeId) + return nil, fmt.Errorf("variable %s not found in node %s", v.Selector.Name, v.Selector.Id) } - - return variables[v.Selector.Id][v.Selector.Name], nil + return &EvalResult{ + Type: v.Selector.Type, + Value: variables[v.Selector.Id][v.Selector.Name], + }, nil } type CompareExpr struct { @@ -73,7 +333,7 @@ type CompareExpr struct { func (c CompareExpr) GetType() string { return c.Type } -func (c CompareExpr) Eval(variables map[string]map[string]any) (any, error) { +func (c CompareExpr) Eval(variables map[string]map[string]any) (*EvalResult, error) { left, err := c.Left.Eval(variables) if err != nil { return nil, err @@ -85,19 +345,19 @@ func (c CompareExpr) Eval(variables map[string]map[string]any) (any, error) { switch c.Op { case GreaterThan: - return left.(float64) > right.(float64), nil + return left.GreaterThan(right) case LessThan: - return left.(float64) < right.(float64), nil + return left.LessThan(right) case GreaterOrEqual: - return left.(float64) >= right.(float64), nil + return left.GreaterOrEqual(right) case LessOrEqual: - return left.(float64) <= right.(float64), nil + return left.LessOrEqual(right) case Equal: - return left == right, nil + return left.Equal(right) case NotEqual: - return left != right, nil + return left.NotEqual(right) case Is: - return left == right, nil + return left.Is(right) default: return nil, fmt.Errorf("unknown operator: %s", c.Op) } @@ -112,7 +372,7 @@ type LogicalExpr struct { func (l LogicalExpr) GetType() string { return l.Type } -func (l LogicalExpr) Eval(variables map[string]map[string]any) (any, error) { +func (l LogicalExpr) Eval(variables map[string]map[string]any) (*EvalResult, error) { left, err := l.Left.Eval(variables) if err != nil { return nil, err @@ -124,9 +384,9 @@ func (l LogicalExpr) Eval(variables map[string]map[string]any) (any, error) { switch l.Op { case And: - return left.(bool) && right.(bool), nil + return left.And(right) case Or: - return left.(bool) || right.(bool), nil + return left.Or(right) default: return nil, fmt.Errorf("unknown operator: %s", l.Op) } @@ -139,12 +399,12 @@ type NotExpr struct { func (n NotExpr) GetType() string { return n.Type } -func (n NotExpr) Eval(variables map[string]map[string]any) (any, error) { +func (n NotExpr) Eval(variables map[string]map[string]any) (*EvalResult, error) { inner, err := n.Expr.Eval(variables) if err != nil { return nil, err } - return !inner.(bool), nil + return inner.Not() } type rawExpr struct { diff --git a/internal/domain/expr_test.go b/internal/domain/expr_test.go new file mode 100644 index 00000000..f0a34504 --- /dev/null +++ b/internal/domain/expr_test.go @@ -0,0 +1,127 @@ +package domain + +import ( + "testing" +) + +func TestLogicalEval(t *testing.T) { + // 测试逻辑表达式 and + logicalExpr := LogicalExpr{ + Left: ConstExpr{ + Type: "const", + Value: true, + ValueType: "boolean", + }, + Op: And, + Right: ConstExpr{ + Type: "const", + Value: true, + ValueType: "boolean", + }, + } + result, err := logicalExpr.Eval(nil) + if err != nil { + t.Errorf("failed to evaluate logical expression: %v", err) + } + if result.Value != true { + t.Errorf("expected true, got %v", result) + } + + // 测试逻辑表达式 or + orExpr := LogicalExpr{ + Left: ConstExpr{ + Type: "const", + Value: true, + ValueType: "boolean", + }, + Op: Or, + Right: ConstExpr{ + Type: "const", + Value: true, + ValueType: "boolean", + }, + } + result, err = orExpr.Eval(nil) + if err != nil { + t.Errorf("failed to evaluate logical expression: %v", err) + } + if result.Value != true { + t.Errorf("expected true, got %v", result) + } +} + +func TestUnmarshalExpr(t *testing.T) { + type args struct { + data []byte + } + tests := []struct { + name string + args args + want Expr + wantErr bool + }{ + { + name: "test1", + args: args{ + data: []byte(`{"left":{"left":{"selector":{"id":"ODnYSOXB6HQP2_vz6JcZE","name":"certificate.validated","type":"boolean"},"type":"var"},"op":"is","right":{"type":"const","value":true,"valueType":"boolean"},"type":"compare"},"op":"and","right":{"left":{"selector":{"id":"ODnYSOXB6HQP2_vz6JcZE","name":"certificate.daysLeft","type":"number"},"type":"var"},"op":"==","right":{"type":"const","value":2,"valueType":"number"},"type":"compare"},"type":"logical"}`), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := UnmarshalExpr(tt.args.data) + if (err != nil) != tt.wantErr { + t.Errorf("UnmarshalExpr() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got == nil { + t.Errorf("UnmarshalExpr() got = nil, want %v", tt.want) + return + } + }) + } +} + +func TestExpr_Eval(t *testing.T) { + type args struct { + variables map[string]map[string]any + data []byte + } + tests := []struct { + name string + args args + want *EvalResult + wantErr bool + }{ + { + name: "test1", + args: args{ + variables: map[string]map[string]any{ + "ODnYSOXB6HQP2_vz6JcZE": { + "certificate.validated": true, + "certificate.daysLeft": 2, + }, + }, + data: []byte(`{"left":{"left":{"selector":{"id":"ODnYSOXB6HQP2_vz6JcZE","name":"certificate.validated","type":"boolean"},"type":"var"},"op":"is","right":{"type":"const","value":true,"valueType":"boolean"},"type":"compare"},"op":"and","right":{"left":{"selector":{"id":"ODnYSOXB6HQP2_vz6JcZE","name":"certificate.daysLeft","type":"number"},"type":"var"},"op":"==","right":{"type":"const","value":2,"valueType":"number"},"type":"compare"},"type":"logical"}`), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, err := UnmarshalExpr(tt.args.data) + if err != nil { + t.Errorf("UnmarshalExpr() error = %v", err) + return + } + got, err := c.Eval(tt.args.variables) + t.Log("got:", got) + if (err != nil) != tt.wantErr { + t.Errorf("ConstExpr.Eval() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got.Value != true { + t.Errorf("ConstExpr.Eval() got = %v, want %v", got.Value, true) + } + }) + } +} diff --git a/internal/domain/workflow.go b/internal/domain/workflow.go index e1e72354..63237221 100644 --- a/internal/domain/workflow.go +++ b/internal/domain/workflow.go @@ -1,6 +1,7 @@ package domain import ( + "encoding/json" "time" maputil "github.com/usual2970/certimate/internal/pkg/utils/map" @@ -109,11 +110,13 @@ type WorkflowNodeConfigForNotify struct { } func (n *WorkflowNode) GetConfigForCondition() WorkflowNodeConfigForCondition { - raw := maputil.GetString(n.Config, "expression") - if raw == "" { + expression := n.Config["expression"] + if expression == nil { return WorkflowNodeConfigForCondition{} } + raw, _ := json.Marshal(expression) + expr, err := UnmarshalExpr([]byte(raw)) if err != nil { return WorkflowNodeConfigForCondition{} diff --git a/internal/workflow/dispatcher/invoker.go b/internal/workflow/dispatcher/invoker.go index a4de08e7..b6e4a4db 100644 --- a/internal/workflow/dispatcher/invoker.go +++ b/internal/workflow/dispatcher/invoker.go @@ -98,7 +98,9 @@ func (w *workflowInvoker) processNode(ctx context.Context, node *domain.Workflow procErr = processor.Process(ctx) if procErr != nil { - processor.GetLogger().Error(procErr.Error()) + if current.Type != domain.WorkflowNodeTypeCondition { + processor.GetLogger().Error(procErr.Error()) + } break } @@ -110,9 +112,12 @@ func (w *workflowInvoker) processNode(ctx context.Context, node *domain.Workflow break } - // TODO: 优化可读性 - if procErr != nil && current.Next != nil && current.Next.Type != domain.WorkflowNodeTypeExecuteResultBranch { + if procErr != nil && current.Type == domain.WorkflowNodeTypeCondition { + current = nil + procErr = nil + return nil + } else if procErr != nil && current.Next != nil && current.Next.Type != domain.WorkflowNodeTypeExecuteResultBranch { return procErr } else if procErr != nil && current.Next != nil && current.Next.Type == domain.WorkflowNodeTypeExecuteResultBranch { current = w.getBranchByType(current.Next.Branches, domain.WorkflowNodeTypeExecuteFailure) diff --git a/internal/workflow/node-processor/condition_node.go b/internal/workflow/node-processor/condition_node.go index f8ed228b..d90811d9 100644 --- a/internal/workflow/node-processor/condition_node.go +++ b/internal/workflow/node-processor/condition_node.go @@ -26,27 +26,26 @@ func (n *conditionNode) Process(ctx context.Context) error { nodeConfig := n.node.GetConfigForCondition() if nodeConfig.Expression == nil { + n.logger.Info("no condition found, continue to next node") return nil } + + rs, err := n.eval(ctx, nodeConfig.Expression) + if err != nil { + n.logger.Warn("failed to eval expression: " + err.Error()) + return err + } + + if rs.Value == false { + n.logger.Info("condition not met, skip this branch") + return errors.New("condition not met") + } + + n.logger.Info("condition met, continue to next node") return nil } -func (n *conditionNode) eval(ctx context.Context, expression domain.Expr) (any, error) { - switch expr:=expression.(type) { - case domain.CompareExpr: - left,err:= n.eval(ctx, expr.Left) - if err != nil { - return nil, err - } - right,err:= n.eval(ctx, expr.Right) - if err != nil { - return nil, err - } - - case domain.LogicalExpr: - case domain.NotExpr: - case domain.VarExpr: - case domain.ConstExpr: - } - return false, errors.New("unknown expression type") +func (n *conditionNode) eval(ctx context.Context, expression domain.Expr) (*domain.EvalResult, error) { + variables := GetNodeOutputs(ctx) + return expression.Eval(variables) } diff --git a/ui/src/components/workflow/node/ConditionNode.tsx b/ui/src/components/workflow/node/ConditionNode.tsx index d3f1defc..7b2cb554 100644 --- a/ui/src/components/workflow/node/ConditionNode.tsx +++ b/ui/src/components/workflow/node/ConditionNode.tsx @@ -57,7 +57,7 @@ const ConditionNode = ({ node, disabled, branchId, branchIndex }: ConditionNodeP break; } - const right: Expr = { type: "const", value: value }; + const right: Expr = { type: "const", value: value, valueType: t }; return { type: "compare", diff --git a/ui/src/components/workflow/node/ConditionNodeConfigForm.tsx b/ui/src/components/workflow/node/ConditionNodeConfigForm.tsx index f2d08253..e040dc78 100644 --- a/ui/src/components/workflow/node/ConditionNodeConfigForm.tsx +++ b/ui/src/components/workflow/node/ConditionNodeConfigForm.tsx @@ -318,6 +318,7 @@ const formToExpression = (values: ConditionNodeConfigFormFieldValues): Expr => { const right: Expr = { type: "const", value: rightValue, + valueType: type, }; return { diff --git a/ui/src/domain/workflow.ts b/ui/src/domain/workflow.ts index 9cd12287..05c936a7 100644 --- a/ui/src/domain/workflow.ts +++ b/ui/src/domain/workflow.ts @@ -238,7 +238,7 @@ export type ComparisonOperator = ">" | "<" | ">=" | "<=" | "==" | "!=" | "is"; export type LogicalOperator = "and" | "or" | "not"; -export type ConstExpr = { type: "const"; value: Value }; +export type ConstExpr = { type: "const"; value: Value; valueType: WorkflowNodeIoValueType }; export type VarExpr = { type: "var"; selector: WorkflowNodeIOValueSelector }; export type CompareExpr = { type: "compare"; op: ComparisonOperator; left: Expr; right: Expr }; export type LogicalExpr = { type: "logical"; op: LogicalOperator; left: Expr; right: Expr };