diff --git a/src/numbers/numbers.go b/src/numbers/numbers.go index 90d2c80..c82c9c4 100644 --- a/src/numbers/numbers.go +++ b/src/numbers/numbers.go @@ -2,8 +2,8 @@ package main import ( "bytes" + "context" "encoding/json" - "flag" "fmt" "io/ioutil" "log" @@ -50,12 +50,8 @@ func main() { // If all URLs take too long to load or return garbage, an empty JSON list // is returned. func numbersHandler(w http.ResponseWriter, r *http.Request) { - // timeout channel for the handler as a whole - timeout := make(chan bool, 1) - go func() { - time.Sleep(MaxResponseTime) - timeout <- true - }() + ctx, cancel := context.WithTimeout(context.Background(), MaxResponseTime) + defer cancel() var rurl []string = r.URL.Query()["u"] // if no parameters, return 400 @@ -76,7 +72,7 @@ func numbersHandler(w http.ResponseWriter, r *http.Request) { wg.Add(1) go func(url string) { defer wg.Done() - n, e := getNumbers(url) + n, e := getNumbers(url, ctx) if e == nil { if n != nil && len(n) > 0 { @@ -109,7 +105,7 @@ func numbersHandler(w http.ResponseWriter, r *http.Request) { done := false for done != true { select { - case <-timeout: + case <-ctx.Done(): log.Printf("Waiting for URL took too long, finishing response anyway") finishResponse(w, sortedNumbers) return @@ -134,7 +130,7 @@ func numbersHandler(w http.ResponseWriter, r *http.Request) { // sort fallthrough, either the inputChan is currently "empty" // or we fetched all URLs already go func(n []int) { - res, err := sort.SortedAndDedup(timeout, n) + res, err := sort.SortedAndDedup(ctx, n) if err != nil { return } @@ -145,7 +141,7 @@ func numbersHandler(w http.ResponseWriter, r *http.Request) { select { case merged := <-sortChan: sortedNumbers = merged - case <-timeout: + case <-ctx.Done(): log.Printf("Sorting took too long, finishing response anyway") finishResponse(w, sortedNumbers) return @@ -168,7 +164,7 @@ func finishResponse(w http.ResponseWriter, numbers []int) { // 'resp' is always nil if there was an error. Errors can // be url parse errors, HTTP response errors, io errors from reading the // body or json decoding errors. -func getNumbers(rawurl string) (resp []int, err error) { +func getNumbers(rawurl string, ctx context.Context) (resp []int, err error) { // validate url u_err := validateURL(rawurl) if u_err != nil { @@ -176,10 +172,18 @@ func getNumbers(rawurl string) (resp []int, err error) { } // retrieve response - r, r_err := http.Get(rawurl) + client := &http.Client{} + req, r_err := http.NewRequest("GET", rawurl, nil) if r_err != nil { return nil, r_err } + req = req.WithContext(ctx) + r, err := client.Do(req) + + if err != nil { + return nil, err + } + if r.StatusCode != 200 { return nil, fmt.Errorf("HTTP: Status code is not 200, but %d", r.StatusCode) diff --git a/src/numbers/sort/sort.go b/src/numbers/sort/sort.go index da90cb2..6c35248 100644 --- a/src/numbers/sort/sort.go +++ b/src/numbers/sort/sort.go @@ -2,12 +2,13 @@ package sort import ( + "context" "fmt" ) // Mergesorts and deduplicates the list. -func SortedAndDedup(timeout <-chan bool, list []int) (res []int, err error) { - sorted, err := Mergesort(timeout, list) +func SortedAndDedup(ctx context.Context, list []int) (res []int, err error) { + sorted, err := Mergesort(ctx, list) if err != nil { return nil, err } @@ -39,7 +40,7 @@ func dedupSortedList(list []int) []int { // is not modified. // The algorithm is a bottom-up iterative version and not explained // in detail here. -func Mergesort(timeout <-chan bool, list []int) (res []int, err error) { +func Mergesort(ctx context.Context, list []int) (res []int, err error) { newList := append([]int{}, list...) temp := append([]int{}, list...) n := len(newList) @@ -47,7 +48,7 @@ func Mergesort(timeout <-chan bool, list []int) (res []int, err error) { for m := 1; m < (n - 1); m = 2 * m { for i := 0; i < (n - 1); i += 2 * m { select { - case <-timeout: + case <-ctx.Done(): return nil, fmt.Errorf("Sorting timed out") default: } @@ -55,7 +56,7 @@ func Mergesort(timeout <-chan bool, list []int) (res []int, err error) { mid := i + m - 1 to := min(i+2*m-1, n-1) - merge(timeout, newList, temp, from, mid, to) + merge(ctx, newList, temp, from, mid, to) } } @@ -63,7 +64,7 @@ func Mergesort(timeout <-chan bool, list []int) (res []int, err error) { } // The merge part of the mergesort. -func merge(timeout <-chan bool, list []int, temp []int, from int, mid int, to int) { +func merge(ctx context.Context, list []int, temp []int, from int, mid int, to int) { k := from i := from j := mid + 1 diff --git a/src/numbers/sort/sort_test.go b/src/numbers/sort/sort_test.go index 7ce4092..32595fd 100644 --- a/src/numbers/sort/sort_test.go +++ b/src/numbers/sort/sort_test.go @@ -1,11 +1,15 @@ package sort import ( + "context" "testing" ) // Test the mergesort and deduplication with a predefined set of slices. func TestSortAndDedup(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + to_sort := [][]int{ {}, {7}, @@ -29,7 +33,7 @@ func TestSortAndDedup(t *testing.T) { } for i := range to_sort { - sorted, _ := SortedAndDedup(make(chan bool, 1), to_sort[i]) + sorted, _ := SortedAndDedup(ctx, to_sort[i]) if slice_equal(sorted, result[i]) != true { t.Errorf("Failure in sorting + dedup, expected %s got %s", result[i], sorted) } @@ -39,6 +43,9 @@ func TestSortAndDedup(t *testing.T) { // Test the mergesort with a predefined set of slices. func TestSort(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // ok to_sort := [][]int{ {}, @@ -61,7 +68,7 @@ func TestSort(t *testing.T) { } for i := range to_sort { - sorted, _ := Mergesort(make(chan bool, 1), to_sort[i]) + sorted, _ := Mergesort(ctx, to_sort[i]) if slice_equal(sorted, result[i]) != true { t.Errorf("Failure in sorting, expected %s got %s", result[i], sorted) }