Use context.Context for cancellation

This commit is contained in:
Julian Ospald 2017-09-11 12:12:43 +02:00
parent 771f7c6892
commit 17aa010ebf
3 changed files with 33 additions and 21 deletions

View File

@ -2,8 +2,8 @@ package main
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"flag"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"log" "log"
@ -50,12 +50,8 @@ func main() {
// If all URLs take too long to load or return garbage, an empty JSON list // If all URLs take too long to load or return garbage, an empty JSON list
// is returned. // is returned.
func numbersHandler(w http.ResponseWriter, r *http.Request) { func numbersHandler(w http.ResponseWriter, r *http.Request) {
// timeout channel for the handler as a whole ctx, cancel := context.WithTimeout(context.Background(), MaxResponseTime)
timeout := make(chan bool, 1) defer cancel()
go func() {
time.Sleep(MaxResponseTime)
timeout <- true
}()
var rurl []string = r.URL.Query()["u"] var rurl []string = r.URL.Query()["u"]
// if no parameters, return 400 // if no parameters, return 400
@ -76,7 +72,7 @@ func numbersHandler(w http.ResponseWriter, r *http.Request) {
wg.Add(1) wg.Add(1)
go func(url string) { go func(url string) {
defer wg.Done() defer wg.Done()
n, e := getNumbers(url) n, e := getNumbers(url, ctx)
if e == nil { if e == nil {
if n != nil && len(n) > 0 { if n != nil && len(n) > 0 {
@ -109,7 +105,7 @@ func numbersHandler(w http.ResponseWriter, r *http.Request) {
done := false done := false
for done != true { for done != true {
select { select {
case <-timeout: case <-ctx.Done():
log.Printf("Waiting for URL took too long, finishing response anyway") log.Printf("Waiting for URL took too long, finishing response anyway")
finishResponse(w, sortedNumbers) finishResponse(w, sortedNumbers)
return return
@ -134,7 +130,7 @@ func numbersHandler(w http.ResponseWriter, r *http.Request) {
// sort fallthrough, either the inputChan is currently "empty" // sort fallthrough, either the inputChan is currently "empty"
// or we fetched all URLs already // or we fetched all URLs already
go func(n []int) { go func(n []int) {
res, err := sort.SortedAndDedup(timeout, n) res, err := sort.SortedAndDedup(ctx, n)
if err != nil { if err != nil {
return return
} }
@ -145,7 +141,7 @@ func numbersHandler(w http.ResponseWriter, r *http.Request) {
select { select {
case merged := <-sortChan: case merged := <-sortChan:
sortedNumbers = merged sortedNumbers = merged
case <-timeout: case <-ctx.Done():
log.Printf("Sorting took too long, finishing response anyway") log.Printf("Sorting took too long, finishing response anyway")
finishResponse(w, sortedNumbers) finishResponse(w, sortedNumbers)
return return
@ -168,7 +164,7 @@ func finishResponse(w http.ResponseWriter, numbers []int) {
// 'resp' is always nil if there was an error. Errors can // 'resp' is always nil if there was an error. Errors can
// be url parse errors, HTTP response errors, io errors from reading the // be url parse errors, HTTP response errors, io errors from reading the
// body or json decoding errors. // body or json decoding errors.
func getNumbers(rawurl string) (resp []int, err error) { func getNumbers(rawurl string, ctx context.Context) (resp []int, err error) {
// validate url // validate url
u_err := validateURL(rawurl) u_err := validateURL(rawurl)
if u_err != nil { if u_err != nil {
@ -176,10 +172,18 @@ func getNumbers(rawurl string) (resp []int, err error) {
} }
// retrieve response // retrieve response
r, r_err := http.Get(rawurl) client := &http.Client{}
req, r_err := http.NewRequest("GET", rawurl, nil)
if r_err != nil { if r_err != nil {
return nil, r_err return nil, r_err
} }
req = req.WithContext(ctx)
r, err := client.Do(req)
if err != nil {
return nil, err
}
if r.StatusCode != 200 { if r.StatusCode != 200 {
return nil, fmt.Errorf("HTTP: Status code is not 200, but %d", return nil, fmt.Errorf("HTTP: Status code is not 200, but %d",
r.StatusCode) r.StatusCode)

View File

@ -2,12 +2,13 @@
package sort package sort
import ( import (
"context"
"fmt" "fmt"
) )
// Mergesorts and deduplicates the list. // Mergesorts and deduplicates the list.
func SortedAndDedup(timeout <-chan bool, list []int) (res []int, err error) { func SortedAndDedup(ctx context.Context, list []int) (res []int, err error) {
sorted, err := Mergesort(timeout, list) sorted, err := Mergesort(ctx, list)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -39,7 +40,7 @@ func dedupSortedList(list []int) []int {
// is not modified. // is not modified.
// The algorithm is a bottom-up iterative version and not explained // The algorithm is a bottom-up iterative version and not explained
// in detail here. // in detail here.
func Mergesort(timeout <-chan bool, list []int) (res []int, err error) { func Mergesort(ctx context.Context, list []int) (res []int, err error) {
newList := append([]int{}, list...) newList := append([]int{}, list...)
temp := append([]int{}, list...) temp := append([]int{}, list...)
n := len(newList) n := len(newList)
@ -47,7 +48,7 @@ func Mergesort(timeout <-chan bool, list []int) (res []int, err error) {
for m := 1; m < (n - 1); m = 2 * m { for m := 1; m < (n - 1); m = 2 * m {
for i := 0; i < (n - 1); i += 2 * m { for i := 0; i < (n - 1); i += 2 * m {
select { select {
case <-timeout: case <-ctx.Done():
return nil, fmt.Errorf("Sorting timed out") return nil, fmt.Errorf("Sorting timed out")
default: default:
} }
@ -55,7 +56,7 @@ func Mergesort(timeout <-chan bool, list []int) (res []int, err error) {
mid := i + m - 1 mid := i + m - 1
to := min(i+2*m-1, n-1) to := min(i+2*m-1, n-1)
merge(timeout, newList, temp, from, mid, to) merge(ctx, newList, temp, from, mid, to)
} }
} }
@ -63,7 +64,7 @@ func Mergesort(timeout <-chan bool, list []int) (res []int, err error) {
} }
// The merge part of the mergesort. // The merge part of the mergesort.
func merge(timeout <-chan bool, list []int, temp []int, from int, mid int, to int) { func merge(ctx context.Context, list []int, temp []int, from int, mid int, to int) {
k := from k := from
i := from i := from
j := mid + 1 j := mid + 1

View File

@ -1,11 +1,15 @@
package sort package sort
import ( import (
"context"
"testing" "testing"
) )
// Test the mergesort and deduplication with a predefined set of slices. // Test the mergesort and deduplication with a predefined set of slices.
func TestSortAndDedup(t *testing.T) { func TestSortAndDedup(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
to_sort := [][]int{ to_sort := [][]int{
{}, {},
{7}, {7},
@ -29,7 +33,7 @@ func TestSortAndDedup(t *testing.T) {
} }
for i := range to_sort { for i := range to_sort {
sorted, _ := SortedAndDedup(make(chan bool, 1), to_sort[i]) sorted, _ := SortedAndDedup(ctx, to_sort[i])
if slice_equal(sorted, result[i]) != true { if slice_equal(sorted, result[i]) != true {
t.Errorf("Failure in sorting + dedup, expected %s got %s", result[i], sorted) t.Errorf("Failure in sorting + dedup, expected %s got %s", result[i], sorted)
} }
@ -39,6 +43,9 @@ func TestSortAndDedup(t *testing.T) {
// Test the mergesort with a predefined set of slices. // Test the mergesort with a predefined set of slices.
func TestSort(t *testing.T) { func TestSort(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// ok // ok
to_sort := [][]int{ to_sort := [][]int{
{}, {},
@ -61,7 +68,7 @@ func TestSort(t *testing.T) {
} }
for i := range to_sort { for i := range to_sort {
sorted, _ := Mergesort(make(chan bool, 1), to_sort[i]) sorted, _ := Mergesort(ctx, to_sort[i])
if slice_equal(sorted, result[i]) != true { if slice_equal(sorted, result[i]) != true {
t.Errorf("Failure in sorting, expected %s got %s", result[i], sorted) t.Errorf("Failure in sorting, expected %s got %s", result[i], sorted)
} }