2025-06-13 13:37:14 +08:00

306 lines
7.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package dispatcher
import (
"context"
"errors"
"fmt"
"log"
"os"
"runtime"
"runtime/debug"
"strconv"
"sync"
"time"
"github.com/usual2970/certimate/internal/app"
"github.com/usual2970/certimate/internal/domain"
sliceutil "github.com/usual2970/certimate/internal/pkg/utils/slice"
)
var maxWorkers = 1
func init() {
envMaxWorkers := os.Getenv("CERTIMATE_WORKFLOW_MAX_WORKERS")
if n, err := strconv.Atoi(envMaxWorkers); err != nil && n > 0 {
maxWorkers = n
} else {
maxWorkers = runtime.GOMAXPROCS(0)
if maxWorkers == 0 {
maxWorkers = max(1, runtime.NumCPU())
}
}
}
type workflowWorker struct {
Data *WorkflowWorkerData
Cancel context.CancelFunc
}
type WorkflowWorkerData struct {
WorkflowId string
WorkflowContent *domain.WorkflowNode
RunId string
}
type WorkflowDispatcher struct {
semaphore chan struct{}
queue []*WorkflowWorkerData
queueMutex sync.Mutex
workers map[string]*workflowWorker // key: WorkflowId
workerIdMap map[string]string // key: RunId, value: WorkflowId
workerMutex sync.Mutex
chWork chan *WorkflowWorkerData
chCandi chan struct{}
wg sync.WaitGroup
workflowRepo workflowRepository
workflowRunRepo workflowRunRepository
workflowLogRepo workflowLogRepository
}
func newWorkflowDispatcher(workflowRepo workflowRepository, workflowRunRepo workflowRunRepository, workflowLogRepo workflowLogRepository) *WorkflowDispatcher {
dispatcher := &WorkflowDispatcher{
semaphore: make(chan struct{}, maxWorkers),
queue: make([]*WorkflowWorkerData, 0),
queueMutex: sync.Mutex{},
workers: make(map[string]*workflowWorker),
workerIdMap: make(map[string]string),
workerMutex: sync.Mutex{},
chWork: make(chan *WorkflowWorkerData),
chCandi: make(chan struct{}, 1),
workflowRepo: workflowRepo,
workflowRunRepo: workflowRunRepo,
workflowLogRepo: workflowLogRepo,
}
go func() {
for {
select {
case <-dispatcher.chWork:
dispatcher.dequeueWorker()
case <-dispatcher.chCandi:
dispatcher.dequeueWorker()
}
}
}()
return dispatcher
}
func (d *WorkflowDispatcher) Dispatch(data *WorkflowWorkerData) {
if data == nil {
panic("worker data is nil")
}
d.enqueueWorker(data)
select {
case d.chWork <- data:
default:
}
}
func (d *WorkflowDispatcher) Cancel(runId string) {
hasWorker := false
// 取消正在执行的 WorkflowRun
d.workerMutex.Lock()
if workflowId, ok := d.workerIdMap[runId]; ok {
if worker, ok := d.workers[workflowId]; ok {
hasWorker = true
worker.Cancel()
delete(d.workers, workflowId)
delete(d.workerIdMap, runId)
}
}
d.workerMutex.Unlock()
// 移除排队中的 WorkflowRun
d.queueMutex.Lock()
d.queue = sliceutil.Filter(d.queue, func(d *WorkflowWorkerData) bool {
return d.RunId != runId
})
d.queueMutex.Unlock()
// 已挂起,查询 WorkflowRun 并更新其状态为 Canceled
if !hasWorker {
if run, err := d.workflowRunRepo.GetById(context.Background(), runId); err == nil {
if run.Status == domain.WorkflowRunStatusTypePending || run.Status == domain.WorkflowRunStatusTypeRunning {
run.Status = domain.WorkflowRunStatusTypeCanceled
d.workflowRunRepo.Save(context.Background(), run)
}
}
}
}
func (d *WorkflowDispatcher) Shutdown() {
// 清空排队中的 WorkflowRun
d.queueMutex.Lock()
d.queue = make([]*WorkflowWorkerData, 0)
d.queueMutex.Unlock()
// 等待所有正在执行的 WorkflowRun 完成
d.workerMutex.Lock()
for _, worker := range d.workers {
worker.Cancel()
delete(d.workers, worker.Data.WorkflowId)
delete(d.workerIdMap, worker.Data.RunId)
}
d.workerMutex.Unlock()
d.wg.Wait()
}
func (d *WorkflowDispatcher) enqueueWorker(data *WorkflowWorkerData) {
d.queueMutex.Lock()
defer d.queueMutex.Unlock()
d.queue = append(d.queue, data)
}
func (d *WorkflowDispatcher) dequeueWorker() {
for {
select {
case d.semaphore <- struct{}{}:
default:
// 达到最大并发数
return
}
d.queueMutex.Lock()
if len(d.queue) == 0 {
d.queueMutex.Unlock()
<-d.semaphore
return
}
data := d.queue[0]
d.queue = d.queue[1:]
d.queueMutex.Unlock()
// 检查是否有相同 WorkflowId 的 WorkflowRun 正在执行
// 如果有,则重新排队,以保证同一个工作流同一时间内只有一个正在执行
// 即不同 WorkflowId 的任务并行化,相同 WorkflowId 的任务串行化
d.workerMutex.Lock()
if _, exists := d.workers[data.WorkflowId]; exists {
d.queueMutex.Lock()
d.queue = append(d.queue, data)
d.queueMutex.Unlock()
d.workerMutex.Unlock()
<-d.semaphore
continue
}
ctx, cancel := context.WithCancel(context.Background())
d.workers[data.WorkflowId] = &workflowWorker{data, cancel}
d.workerIdMap[data.RunId] = data.WorkflowId
d.workerMutex.Unlock()
d.wg.Add(1)
go d.work(ctx, data)
}
}
func (d *WorkflowDispatcher) work(ctx context.Context, data *WorkflowWorkerData) {
var run *domain.WorkflowRun
var err error
defer func() {
// 捕获 panic避免影响其他工作流的执行
if r := recover(); r != nil {
log.Default().Println("WorkflowId:", data.WorkflowId, "RunId:", data.RunId)
log.Default().Println("Recovered from panic:", r)
log.Default().Println("Stack trace:", string(debug.Stack()))
if run != nil {
run.Status = domain.WorkflowRunStatusTypeFailed
run.EndedAt = time.Now()
run.Error = fmt.Sprintf("workflow run panic: %v", r)
if _, err := d.workflowRunRepo.Save(ctx, run); err != nil {
log.Default().Println("Failed to save workflow run after panic:", err)
}
}
}
<-d.semaphore
d.workerMutex.Lock()
delete(d.workers, data.WorkflowId)
delete(d.workerIdMap, data.RunId)
d.workerMutex.Unlock()
d.wg.Done()
// 尝试取出排队中的其他 WorkflowRun 继续执行
select {
case d.chCandi <- struct{}{}:
default:
}
}()
// 查询 WorkflowRun
run, err = d.workflowRunRepo.GetById(ctx, data.RunId)
if err != nil {
if !(errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded)) {
app.GetLogger().Error(fmt.Sprintf("failed to get workflow run #%s", data.RunId), "err", err)
}
return
} else if run.Status != domain.WorkflowRunStatusTypePending {
return
} else if ctx.Err() != nil {
run.Status = domain.WorkflowRunStatusTypeCanceled
d.workflowRunRepo.Save(ctx, run)
return
}
// 更新 WorkflowRun 状态为 Running
run.Status = domain.WorkflowRunStatusTypeRunning
if _, err := d.workflowRunRepo.Save(ctx, run); err != nil {
if !(errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded)) {
panic(err)
}
return
}
// 执行工作流
invoker := newWorkflowInvokerWithData(d.workflowLogRepo, data)
if runErr := invoker.Invoke(ctx); runErr != nil {
if errors.Is(runErr, context.Canceled) {
run.Status = domain.WorkflowRunStatusTypeCanceled
} else {
run.Status = domain.WorkflowRunStatusTypeFailed
run.EndedAt = time.Now()
run.Error = runErr.Error()
}
if _, err := d.workflowRunRepo.Save(ctx, run); err != nil {
if !(errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded)) {
panic(err)
}
}
return
}
// 更新 WorkflowRun 状态为 Succeeded/Failed
run.EndedAt = time.Now()
run.Error = invoker.GetLogs().ErrorString()
if run.Error == "" {
run.Status = domain.WorkflowRunStatusTypeSucceeded
} else {
run.Status = domain.WorkflowRunStatusTypeFailed
}
if _, err := d.workflowRunRepo.Save(ctx, run); err != nil {
if !(errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded)) {
panic(err)
}
}
}