Use context.Context for cancellation

This commit is contained in:
Julian Ospald 2017-09-11 12:12:43 +02:00
parent 771f7c6892
commit 17aa010ebf
3 changed files with 33 additions and 21 deletions

View File

@ -2,8 +2,8 @@ package main
import (
"bytes"
"context"
"encoding/json"
"flag"
"fmt"
"io/ioutil"
"log"
@ -50,12 +50,8 @@ func main() {
// If all URLs take too long to load or return garbage, an empty JSON list
// is returned.
func numbersHandler(w http.ResponseWriter, r *http.Request) {
// timeout channel for the handler as a whole
timeout := make(chan bool, 1)
go func() {
time.Sleep(MaxResponseTime)
timeout <- true
}()
ctx, cancel := context.WithTimeout(context.Background(), MaxResponseTime)
defer cancel()
var rurl []string = r.URL.Query()["u"]
// if no parameters, return 400
@ -76,7 +72,7 @@ func numbersHandler(w http.ResponseWriter, r *http.Request) {
wg.Add(1)
go func(url string) {
defer wg.Done()
n, e := getNumbers(url)
n, e := getNumbers(url, ctx)
if e == nil {
if n != nil && len(n) > 0 {
@ -109,7 +105,7 @@ func numbersHandler(w http.ResponseWriter, r *http.Request) {
done := false
for done != true {
select {
case <-timeout:
case <-ctx.Done():
log.Printf("Waiting for URL took too long, finishing response anyway")
finishResponse(w, sortedNumbers)
return
@ -134,7 +130,7 @@ func numbersHandler(w http.ResponseWriter, r *http.Request) {
// sort fallthrough, either the inputChan is currently "empty"
// or we fetched all URLs already
go func(n []int) {
res, err := sort.SortedAndDedup(timeout, n)
res, err := sort.SortedAndDedup(ctx, n)
if err != nil {
return
}
@ -145,7 +141,7 @@ func numbersHandler(w http.ResponseWriter, r *http.Request) {
select {
case merged := <-sortChan:
sortedNumbers = merged
case <-timeout:
case <-ctx.Done():
log.Printf("Sorting took too long, finishing response anyway")
finishResponse(w, sortedNumbers)
return
@ -168,7 +164,7 @@ func finishResponse(w http.ResponseWriter, numbers []int) {
// 'resp' is always nil if there was an error. Errors can
// be url parse errors, HTTP response errors, io errors from reading the
// body or json decoding errors.
func getNumbers(rawurl string) (resp []int, err error) {
func getNumbers(rawurl string, ctx context.Context) (resp []int, err error) {
// validate url
u_err := validateURL(rawurl)
if u_err != nil {
@ -176,10 +172,18 @@ func getNumbers(rawurl string) (resp []int, err error) {
}
// retrieve response
r, r_err := http.Get(rawurl)
client := &http.Client{}
req, r_err := http.NewRequest("GET", rawurl, nil)
if r_err != nil {
return nil, r_err
}
req = req.WithContext(ctx)
r, err := client.Do(req)
if err != nil {
return nil, err
}
if r.StatusCode != 200 {
return nil, fmt.Errorf("HTTP: Status code is not 200, but %d",
r.StatusCode)

View File

@ -2,12 +2,13 @@
package sort
import (
"context"
"fmt"
)
// Mergesorts and deduplicates the list.
func SortedAndDedup(timeout <-chan bool, list []int) (res []int, err error) {
sorted, err := Mergesort(timeout, list)
func SortedAndDedup(ctx context.Context, list []int) (res []int, err error) {
sorted, err := Mergesort(ctx, list)
if err != nil {
return nil, err
}
@ -39,7 +40,7 @@ func dedupSortedList(list []int) []int {
// is not modified.
// The algorithm is a bottom-up iterative version and not explained
// in detail here.
func Mergesort(timeout <-chan bool, list []int) (res []int, err error) {
func Mergesort(ctx context.Context, list []int) (res []int, err error) {
newList := append([]int{}, list...)
temp := append([]int{}, list...)
n := len(newList)
@ -47,7 +48,7 @@ func Mergesort(timeout <-chan bool, list []int) (res []int, err error) {
for m := 1; m < (n - 1); m = 2 * m {
for i := 0; i < (n - 1); i += 2 * m {
select {
case <-timeout:
case <-ctx.Done():
return nil, fmt.Errorf("Sorting timed out")
default:
}
@ -55,7 +56,7 @@ func Mergesort(timeout <-chan bool, list []int) (res []int, err error) {
mid := i + m - 1
to := min(i+2*m-1, n-1)
merge(timeout, newList, temp, from, mid, to)
merge(ctx, newList, temp, from, mid, to)
}
}
@ -63,7 +64,7 @@ func Mergesort(timeout <-chan bool, list []int) (res []int, err error) {
}
// The merge part of the mergesort.
func merge(timeout <-chan bool, list []int, temp []int, from int, mid int, to int) {
func merge(ctx context.Context, list []int, temp []int, from int, mid int, to int) {
k := from
i := from
j := mid + 1

View File

@ -1,11 +1,15 @@
package sort
import (
"context"
"testing"
)
// Test the mergesort and deduplication with a predefined set of slices.
func TestSortAndDedup(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
to_sort := [][]int{
{},
{7},
@ -29,7 +33,7 @@ func TestSortAndDedup(t *testing.T) {
}
for i := range to_sort {
sorted, _ := SortedAndDedup(make(chan bool, 1), to_sort[i])
sorted, _ := SortedAndDedup(ctx, to_sort[i])
if slice_equal(sorted, result[i]) != true {
t.Errorf("Failure in sorting + dedup, expected %s got %s", result[i], sorted)
}
@ -39,6 +43,9 @@ func TestSortAndDedup(t *testing.T) {
// Test the mergesort with a predefined set of slices.
func TestSort(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// ok
to_sort := [][]int{
{},
@ -61,7 +68,7 @@ func TestSort(t *testing.T) {
}
for i := range to_sort {
sorted, _ := Mergesort(make(chan bool, 1), to_sort[i])
sorted, _ := Mergesort(ctx, to_sort[i])
if slice_equal(sorted, result[i]) != true {
t.Errorf("Failure in sorting, expected %s got %s", result[i], sorted)
}