Created
November 6, 2022 12:59
-
-
Save valsteen/60b39247457d11325cc433afe938ba94 to your computer and use it in GitHub Desktop.
Chaining calls until something fails, with error mapping
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"errors" | |
"fmt" | |
"strconv" | |
"testing" | |
"github.com/stretchr/testify/require" | |
) | |
// identity hands back whatever value it receives, untouched. Chain and Chain2
// instantiate it as identity[error] to act as a pass-through error mapper.
func identity[T any](v T) T {
	return v
}
func Chain[I, R, R1 any]( | |
input I, | |
prelude func(I) (R, error), | |
success1 func(R) (R1, error), | |
) (R1, error) { | |
return ChainMapErr(input, prelude, identity[error], success1, identity[error]) | |
} | |
func Chain2[I, R, R1, R2 any]( | |
input I, | |
prelude func(I) (R, error), | |
success1 func(R) (R1, error), | |
success2 func(R1) (R2, error), | |
) (R2, error) { | |
return ChainMapErr2(input, prelude, identity[error], success1, identity[error], success2, identity[error]) | |
} | |
// ChainMapErr runs prelude on input and, if that succeeds, runs success1 on
// the intermediate result.
//
// The first failing step short-circuits the chain:
//   - an error from prelude is returned through mapError1 (success1 is never called);
//   - an error from success1 is returned through mapError2.
//
// Whenever a step fails, the returned value is the zero value of R1, even if
// the failing step produced a non-zero value alongside its error.
func ChainMapErr[I, R, R1 any](
	input I,
	prelude func(I) (R, error),
	mapError1 func(error) error,
	success1 func(R) (R1, error),
	mapError2 func(error) error,
) (R1, error) {
	intermediate, stepErr := prelude(input)
	if stepErr != nil {
		var none R1
		return none, mapError1(stepErr)
	}

	out, stepErr := success1(intermediate)
	if stepErr != nil {
		var none R1
		return none, mapError2(stepErr)
	}
	return out, nil
}
func ChainMapErr2[I, R, R1, R2 any]( | |
input I, | |
prelude func(I) (R, error), | |
mapError1 func(error) error, | |
success1 func(R) (R1, error), | |
mapError2 func(error) error, | |
success2 func(R1) (R2, error), | |
mapError3 func(error) error, | |
) (R2, error) { | |
first, err := ChainMapErr(input, prelude, mapError1, success1, mapError2) | |
var zero R2 | |
if err != nil { | |
return zero, err | |
} | |
ret2, err := success2(first) | |
if err != nil { | |
return zero, mapError3(err) | |
} | |
return ret2, nil | |
} | |
func process(num int) (int, error) { | |
if num < 0 { | |
return 0, errors.New("cannot be negative") | |
} | |
return num * 3, nil | |
} | |
func render(num int) (string, error) { | |
if num > 1000 { | |
return "", errors.New("too large") | |
} | |
return strconv.Itoa(num), nil | |
} | |
func Test(t *testing.T) { | |
type TestCase struct { | |
input string | |
result string | |
errString string | |
mappedErrString string | |
} | |
for _, testcase := range []TestCase{ | |
{ | |
"1", | |
"3", | |
"", | |
"", | |
}, | |
{ | |
"invalidint", | |
"", | |
"strconv.Atoi: parsing \"invalidint\": invalid syntax", | |
"error while parsing: strconv.Atoi: parsing \"invalidint\": invalid syntax", | |
}, | |
{ | |
"-3", | |
"", | |
"cannot be negative", | |
"error while processing: cannot be negative", | |
}, | |
{ | |
"340", | |
"", | |
"too large", | |
"error while rendering: too large", | |
}, | |
} { | |
result, err := Chain2(testcase.input, strconv.Atoi, process, render) | |
require.Equal(t, result, testcase.result) | |
if err != nil { | |
require.Equal(t, testcase.errString, err.Error()) | |
} else { | |
require.Empty(t, testcase.errString) | |
} | |
result, err = ChainMapErr2( | |
testcase.input, | |
strconv.Atoi, | |
func(e error) error { return fmt.Errorf("error while parsing: %w", e) }, | |
process, | |
func(e error) error { return fmt.Errorf("error while processing: %w", e) }, | |
render, | |
func(e error) error { return fmt.Errorf("error while rendering: %w", e) }, | |
) | |
require.Equal(t, result, testcase.result) | |
if err != nil { | |
require.Equal(t, testcase.mappedErrString, err.Error()) | |
} else { | |
require.Empty(t, testcase.mappedErrString) | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment