Use context.Context for cancellation
This commit is contained in:
parent
771f7c6892
commit
17aa010ebf
@ -2,8 +2,8 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"flag"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"log"
|
"log"
|
||||||
@ -50,12 +50,8 @@ func main() {
|
|||||||
// If all URLs take too long to load or return garbage, an empty JSON list
|
// If all URLs take too long to load or return garbage, an empty JSON list
|
||||||
// is returned.
|
// is returned.
|
||||||
func numbersHandler(w http.ResponseWriter, r *http.Request) {
|
func numbersHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
// timeout channel for the handler as a whole
|
ctx, cancel := context.WithTimeout(context.Background(), MaxResponseTime)
|
||||||
timeout := make(chan bool, 1)
|
defer cancel()
|
||||||
go func() {
|
|
||||||
time.Sleep(MaxResponseTime)
|
|
||||||
timeout <- true
|
|
||||||
}()
|
|
||||||
|
|
||||||
var rurl []string = r.URL.Query()["u"]
|
var rurl []string = r.URL.Query()["u"]
|
||||||
// if no parameters, return 400
|
// if no parameters, return 400
|
||||||
@ -76,7 +72,7 @@ func numbersHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(url string) {
|
go func(url string) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
n, e := getNumbers(url)
|
n, e := getNumbers(url, ctx)
|
||||||
|
|
||||||
if e == nil {
|
if e == nil {
|
||||||
if n != nil && len(n) > 0 {
|
if n != nil && len(n) > 0 {
|
||||||
@ -109,7 +105,7 @@ func numbersHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
done := false
|
done := false
|
||||||
for done != true {
|
for done != true {
|
||||||
select {
|
select {
|
||||||
case <-timeout:
|
case <-ctx.Done():
|
||||||
log.Printf("Waiting for URL took too long, finishing response anyway")
|
log.Printf("Waiting for URL took too long, finishing response anyway")
|
||||||
finishResponse(w, sortedNumbers)
|
finishResponse(w, sortedNumbers)
|
||||||
return
|
return
|
||||||
@ -134,7 +130,7 @@ func numbersHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
// sort fallthrough, either the inputChan is currently "empty"
|
// sort fallthrough, either the inputChan is currently "empty"
|
||||||
// or we fetched all URLs already
|
// or we fetched all URLs already
|
||||||
go func(n []int) {
|
go func(n []int) {
|
||||||
res, err := sort.SortedAndDedup(timeout, n)
|
res, err := sort.SortedAndDedup(ctx, n)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -145,7 +141,7 @@ func numbersHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
select {
|
select {
|
||||||
case merged := <-sortChan:
|
case merged := <-sortChan:
|
||||||
sortedNumbers = merged
|
sortedNumbers = merged
|
||||||
case <-timeout:
|
case <-ctx.Done():
|
||||||
log.Printf("Sorting took too long, finishing response anyway")
|
log.Printf("Sorting took too long, finishing response anyway")
|
||||||
finishResponse(w, sortedNumbers)
|
finishResponse(w, sortedNumbers)
|
||||||
return
|
return
|
||||||
@ -168,7 +164,7 @@ func finishResponse(w http.ResponseWriter, numbers []int) {
|
|||||||
// 'resp' is always nil if there was an error. Errors can
|
// 'resp' is always nil if there was an error. Errors can
|
||||||
// be url parse errors, HTTP response errors, io errors from reading the
|
// be url parse errors, HTTP response errors, io errors from reading the
|
||||||
// body or json decoding errors.
|
// body or json decoding errors.
|
||||||
func getNumbers(rawurl string) (resp []int, err error) {
|
func getNumbers(rawurl string, ctx context.Context) (resp []int, err error) {
|
||||||
// validate url
|
// validate url
|
||||||
u_err := validateURL(rawurl)
|
u_err := validateURL(rawurl)
|
||||||
if u_err != nil {
|
if u_err != nil {
|
||||||
@ -176,10 +172,18 @@ func getNumbers(rawurl string) (resp []int, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// retrieve response
|
// retrieve response
|
||||||
r, r_err := http.Get(rawurl)
|
client := &http.Client{}
|
||||||
|
req, r_err := http.NewRequest("GET", rawurl, nil)
|
||||||
if r_err != nil {
|
if r_err != nil {
|
||||||
return nil, r_err
|
return nil, r_err
|
||||||
}
|
}
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
r, err := client.Do(req)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
if r.StatusCode != 200 {
|
if r.StatusCode != 200 {
|
||||||
return nil, fmt.Errorf("HTTP: Status code is not 200, but %d",
|
return nil, fmt.Errorf("HTTP: Status code is not 200, but %d",
|
||||||
r.StatusCode)
|
r.StatusCode)
|
||||||
|
@ -2,12 +2,13 @@
|
|||||||
package sort
|
package sort
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Mergesorts and deduplicates the list.
|
// Mergesorts and deduplicates the list.
|
||||||
func SortedAndDedup(timeout <-chan bool, list []int) (res []int, err error) {
|
func SortedAndDedup(ctx context.Context, list []int) (res []int, err error) {
|
||||||
sorted, err := Mergesort(timeout, list)
|
sorted, err := Mergesort(ctx, list)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -39,7 +40,7 @@ func dedupSortedList(list []int) []int {
|
|||||||
// is not modified.
|
// is not modified.
|
||||||
// The algorithm is a bottom-up iterative version and not explained
|
// The algorithm is a bottom-up iterative version and not explained
|
||||||
// in detail here.
|
// in detail here.
|
||||||
func Mergesort(timeout <-chan bool, list []int) (res []int, err error) {
|
func Mergesort(ctx context.Context, list []int) (res []int, err error) {
|
||||||
newList := append([]int{}, list...)
|
newList := append([]int{}, list...)
|
||||||
temp := append([]int{}, list...)
|
temp := append([]int{}, list...)
|
||||||
n := len(newList)
|
n := len(newList)
|
||||||
@ -47,7 +48,7 @@ func Mergesort(timeout <-chan bool, list []int) (res []int, err error) {
|
|||||||
for m := 1; m < (n - 1); m = 2 * m {
|
for m := 1; m < (n - 1); m = 2 * m {
|
||||||
for i := 0; i < (n - 1); i += 2 * m {
|
for i := 0; i < (n - 1); i += 2 * m {
|
||||||
select {
|
select {
|
||||||
case <-timeout:
|
case <-ctx.Done():
|
||||||
return nil, fmt.Errorf("Sorting timed out")
|
return nil, fmt.Errorf("Sorting timed out")
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
@ -55,7 +56,7 @@ func Mergesort(timeout <-chan bool, list []int) (res []int, err error) {
|
|||||||
mid := i + m - 1
|
mid := i + m - 1
|
||||||
to := min(i+2*m-1, n-1)
|
to := min(i+2*m-1, n-1)
|
||||||
|
|
||||||
merge(timeout, newList, temp, from, mid, to)
|
merge(ctx, newList, temp, from, mid, to)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -63,7 +64,7 @@ func Mergesort(timeout <-chan bool, list []int) (res []int, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// The merge part of the mergesort.
|
// The merge part of the mergesort.
|
||||||
func merge(timeout <-chan bool, list []int, temp []int, from int, mid int, to int) {
|
func merge(ctx context.Context, list []int, temp []int, from int, mid int, to int) {
|
||||||
k := from
|
k := from
|
||||||
i := from
|
i := from
|
||||||
j := mid + 1
|
j := mid + 1
|
||||||
|
@ -1,11 +1,15 @@
|
|||||||
package sort
|
package sort
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Test the mergesort and deduplication with a predefined set of slices.
|
// Test the mergesort and deduplication with a predefined set of slices.
|
||||||
func TestSortAndDedup(t *testing.T) {
|
func TestSortAndDedup(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
to_sort := [][]int{
|
to_sort := [][]int{
|
||||||
{},
|
{},
|
||||||
{7},
|
{7},
|
||||||
@ -29,7 +33,7 @@ func TestSortAndDedup(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i := range to_sort {
|
for i := range to_sort {
|
||||||
sorted, _ := SortedAndDedup(make(chan bool, 1), to_sort[i])
|
sorted, _ := SortedAndDedup(ctx, to_sort[i])
|
||||||
if slice_equal(sorted, result[i]) != true {
|
if slice_equal(sorted, result[i]) != true {
|
||||||
t.Errorf("Failure in sorting + dedup, expected %s got %s", result[i], sorted)
|
t.Errorf("Failure in sorting + dedup, expected %s got %s", result[i], sorted)
|
||||||
}
|
}
|
||||||
@ -39,6 +43,9 @@ func TestSortAndDedup(t *testing.T) {
|
|||||||
|
|
||||||
// Test the mergesort with a predefined set of slices.
|
// Test the mergesort with a predefined set of slices.
|
||||||
func TestSort(t *testing.T) {
|
func TestSort(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
// ok
|
// ok
|
||||||
to_sort := [][]int{
|
to_sort := [][]int{
|
||||||
{},
|
{},
|
||||||
@ -61,7 +68,7 @@ func TestSort(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i := range to_sort {
|
for i := range to_sort {
|
||||||
sorted, _ := Mergesort(make(chan bool, 1), to_sort[i])
|
sorted, _ := Mergesort(ctx, to_sort[i])
|
||||||
if slice_equal(sorted, result[i]) != true {
|
if slice_equal(sorted, result[i]) != true {
|
||||||
t.Errorf("Failure in sorting, expected %s got %s", result[i], sorted)
|
t.Errorf("Failure in sorting, expected %s got %s", result[i], sorted)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user