Last active
November 30, 2024 14:24
-
-
Save lispyclouds/ba87671c05616f6a1bcd5ae36ce6a4be to your computer and use it in GitHub Desktop.
Go gather tasks with max concurrency control
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"context" | |
"encoding/json" | |
"fmt" | |
"iter" | |
"net/http" | |
"sync" | |
"time" | |
"github.com/spf13/cobra" | |
"golang.org/x/sync/semaphore" | |
) | |
type (
	// Result pairs the value produced by a Task with the error it returned.
	// Val carries whatever the task returned alongside Err; callers should
	// check Err before trusting Val.
	Result[T any] struct {
		Val T
		Err error
	}

	// Task is a unit of work that receives a context and produces a value of
	// type T or an error.
	Task[T any] func(ctx context.Context) (T, error)
)
func worker[T any](ctx context.Context, wg *sync.WaitGroup, in <-chan Task[T], out chan<- Result[T]) { | |
for task := range in { | |
res, err := task(ctx) | |
out <- Result[T]{Val: res, Err: err} | |
wg.Done() | |
} | |
} | |
func gatherWorker[T any](ctx context.Context, maxConcurrency int, tasks ...Task[T]) iter.Seq[Result[T]] { | |
return func(yield func(Result[T]) bool) { | |
var wg sync.WaitGroup | |
send := make(chan Task[T]) | |
recv := make(chan Result[T], len(tasks)) | |
for range maxConcurrency { | |
go worker(ctx, &wg, send, recv) | |
} | |
for _, task := range tasks { | |
send <- task | |
wg.Add(1) | |
} | |
wg.Wait() | |
close(send) | |
close(recv) | |
for result := range recv { | |
if !yield(result) { | |
return | |
} | |
} | |
} | |
} | |
func gatherSem[T any](ctx context.Context, maxConcurrency int, tasks ...Task[T]) iter.Seq[Result[T]] { | |
return func(yield func(Result[T]) bool) { | |
sem := semaphore.NewWeighted(int64(maxConcurrency)) | |
recv := make(chan Result[T], len(tasks)) | |
var wg sync.WaitGroup | |
for _, task := range tasks { | |
sem.Acquire(ctx, 1) | |
wg.Add(1) | |
go func() { | |
res, err := task(ctx) | |
recv <- Result[T]{Val: res, Err: err} | |
sem.Release(1) | |
wg.Done() | |
}() | |
} | |
wg.Wait() | |
close(recv) | |
for result := range recv { | |
if !yield(result) { | |
return | |
} | |
} | |
} | |
} | |
func playMain(cmd *cobra.Command, _ []string) { | |
apiUrl := "http://www.randomnumberapi.com/api/v1.0/random?min=1&max=100" // returns an array of ints | |
taskCount, _ := cmd.Flags().GetInt("tasks") | |
maxConcurrency, _ := cmd.Flags().GetInt("max-concurrency") | |
taskDelay, _ := cmd.Flags().GetInt("task-delay") | |
tasks := []Task[int64]{} | |
for range taskCount { | |
tasks = append(tasks, func(ctx context.Context) (int64, error) { | |
req, err := http.NewRequestWithContext(ctx, "GET", apiUrl, nil) | |
if err != nil { | |
return 0, err | |
} | |
res, err := http.DefaultClient.Do(req) | |
if err != nil { | |
return 0, err | |
} | |
var nums []int64 | |
dec := json.NewDecoder(res.Body) | |
if err = dec.Decode(&nums); err != nil { | |
return 0, err | |
} | |
n := nums[0] | |
fmt.Printf("Adding %d\n", n) | |
time.Sleep(time.Duration(taskDelay) * time.Second) | |
return n, nil | |
}) | |
} | |
if maxConcurrency == 0 { | |
maxConcurrency = len(tasks) | |
} | |
var sum int64 | |
for result := range gatherSem(context.Background(), maxConcurrency, tasks...) { | |
if err := result.Err; err != nil { | |
fmt.Printf("Error: %s\n", err.Error()) | |
return | |
} | |
sum += result.Val | |
} | |
fmt.Printf("Sum is: %d\n", sum) | |
} | |
func main() { | |
cmd := &cobra.Command{ | |
Use: "goplay", | |
Short: "goplay CLI", | |
Long: "CLI for GoPlay", | |
Run: playMain, | |
} | |
cmd.Flags().IntP("tasks", "t", 1, "number of tasks to spawn") | |
cmd.Flags().IntP("max-concurrency", "m", 0, "limit number of tasks to run concurrently, skip for unlimited") | |
cmd.Flags().IntP("task-delay", "d", 1, "added delay in seconds to each task") | |
cmd.Execute() | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment