1package http2interop
2
3import (
4	"crypto/tls"
5	"crypto/x509"
6	"encoding/json"
7	"flag"
8	"fmt"
9	"io/ioutil"
10	"os"
11	"strconv"
12	"strings"
13	"testing"
14)
15
var (
	serverHost = flag.String("server_host", "", "The host to test")
	serverPort = flag.Int("server_port", 443, "The port to test")
	useTls     = flag.Bool("use_tls", true, "Should TLS tests be run")
	testCase   = flag.String("test_case", "", "What test cases to run (tls, framing)")

	// Both of these are read by InteropCtx: server_host_override replaces
	// server_host as the authority when non-empty, and use_test_ca loads the
	// test CA bundle into the context's root certificate pool.
	serverHostOverride = flag.String("server_host_override", "", "Authority to use instead of server_host")
	useTestCa          = flag.Bool("use_test_ca", false, "Load root CAs from src/core/tsi/test_creds/ca.pem")

	// The remaining flags are unused here; they are accepted only to fulfill
	// the shared interop client interface.
	defaultServiceAccount = flag.String("default_service_account", "", "Unused")
	oauthScope            = flag.String("oauth_scope", "", "Unused")
	serviceAccountKeyFile = flag.String("service_account_key_file", "", "Unused")
)
29
30func InteropCtx(t *testing.T) *HTTP2InteropCtx {
31	ctx := &HTTP2InteropCtx{
32		ServerHost:             *serverHost,
33		ServerPort:             *serverPort,
34		ServerHostnameOverride: *serverHostOverride,
35		UseTLS:                 *useTls,
36		UseTestCa:              *useTestCa,
37		T:                      t,
38	}
39
40	ctx.serverSpec = ctx.ServerHost
41	if ctx.ServerPort != -1 {
42		ctx.serverSpec += ":" + strconv.Itoa(ctx.ServerPort)
43	}
44	if ctx.ServerHostnameOverride == "" {
45		ctx.authority = ctx.ServerHost
46	} else {
47		ctx.authority = ctx.ServerHostnameOverride
48	}
49
50	if ctx.UseTestCa {
51		// It would be odd if useTestCa was true, but not useTls.  meh
52		certData, err := ioutil.ReadFile("src/core/tsi/test_creds/ca.pem")
53		if err != nil {
54			t.Fatal(err)
55		}
56
57		ctx.rootCAs = x509.NewCertPool()
58		if !ctx.rootCAs.AppendCertsFromPEM(certData) {
59			t.Fatal(fmt.Errorf("Unable to parse pem data"))
60		}
61	}
62
63	return ctx
64}
65
66func (ctx *HTTP2InteropCtx) Close() error {
67	// currently a noop
68	return nil
69}
70
71func TestSoonClientShortSettings(t *testing.T) {
72	defer Report(t)
73	if *testCase != "framing" {
74		t.SkipNow()
75	}
76	ctx := InteropCtx(t)
77	for i := 1; i <= 5; i++ {
78		err := testClientShortSettings(ctx, i)
79		matchError(t, err, "EOF")
80	}
81}
82
83func TestSoonShortPreface(t *testing.T) {
84	defer Report(t)
85	if *testCase != "framing" {
86		t.SkipNow()
87	}
88	ctx := InteropCtx(t)
89	for i := 0; i < len(Preface)-1; i++ {
90		err := testShortPreface(ctx, Preface[:i]+"X")
91		matchError(t, err, "EOF")
92	}
93}
94
95func TestSoonUnknownFrameType(t *testing.T) {
96	defer Report(t)
97	if *testCase != "framing" {
98		t.SkipNow()
99	}
100	ctx := InteropCtx(t)
101	if err := testUnknownFrameType(ctx); err != nil {
102		t.Fatal(err)
103	}
104}
105
106func TestSoonClientPrefaceWithStreamId(t *testing.T) {
107	defer Report(t)
108	if *testCase != "framing" {
109		t.SkipNow()
110	}
111	ctx := InteropCtx(t)
112	err := testClientPrefaceWithStreamId(ctx)
113	matchError(t, err, "EOF")
114}
115
116func TestSoonTLSApplicationProtocol(t *testing.T) {
117	defer Report(t)
118	if *testCase != "tls" {
119		t.SkipNow()
120	}
121	ctx := InteropCtx(t)
122	err := testTLSApplicationProtocol(ctx)
123	matchError(t, err, "EOF", "broken pipe")
124}
125
126func TestSoonTLSMaxVersion(t *testing.T) {
127	defer Report(t)
128	if *testCase != "tls" {
129		t.SkipNow()
130	}
131	ctx := InteropCtx(t)
132	err := testTLSMaxVersion(ctx, tls.VersionTLS11)
133	// TODO(carl-mastrangelo): maybe this should be some other error.  If the server picks
134	// the wrong protocol version, thats bad too.
135	matchError(t, err, "EOF", "server selected unsupported protocol")
136}
137
138func TestSoonTLSBadCipherSuites(t *testing.T) {
139	defer Report(t)
140	if *testCase != "tls" {
141		t.SkipNow()
142	}
143	ctx := InteropCtx(t)
144	err := testTLSBadCipherSuites(ctx)
145	matchError(t, err, "EOF", "Got goaway frame")
146}
147
// matchError aborts the test via t.Fatal/t.Fatalf unless err is non-nil and
// its message contains at least one of matches as a substring. With no
// matches given, any non-nil error is rejected. It is marked as a test helper,
// so failures are reported at the caller's line rather than here.
func matchError(t *testing.T, err error, matches ...string) {
	t.Helper()
	if err == nil {
		t.Fatalf("Expected an error matching one of %+v", matches)
	}
	msg := err.Error()
	for _, s := range matches {
		if strings.Contains(msg, s) {
			return
		}
	}
	t.Fatalf("Error %v not in %+v", err, matches)
}
159
160func TestMain(m *testing.M) {
161	flag.Parse()
162	m.Run()
163	var fatal bool
164	var any bool
165	for _, ci := range allCaseInfos.Cases {
166		if ci.Skipped {
167			continue
168		}
169		any = true
170		if !ci.Passed && ci.Fatal {
171			fatal = true
172		}
173	}
174
175	if err := json.NewEncoder(os.Stderr).Encode(&allCaseInfos); err != nil {
176		fmt.Println("Failed to encode", err)
177	}
178	var code int
179	if !any || fatal {
180		code = 1
181	}
182	os.Exit(code)
183}
184