1 #region Copyright notice and license
2 
3 // Copyright 2015-2016 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 
17 #endregion
18 
19 using System;
20 using System.Collections.Generic;
21 using System.IO;
22 using System.Linq;
23 using System.Threading;
24 using System.Threading.Tasks;
25 using Google.Protobuf;
26 using Grpc.Core;
27 using Grpc.Core.Utils;
28 using Grpc.Testing;
29 using NUnit.Framework;
30 
31 namespace Grpc.IntegrationTesting
32 {
33     /// <summary>
34     /// Test SSL credentials where server authenticates client
35     /// and client authenticates the server.
36     /// </summary>
37     public class SslCredentialsTest
38     {
39         const string Host = "localhost";
40         Server server;
41         Channel channel;
42         TestService.TestServiceClient client;
43 
44         [OneTimeSetUp]
Init()45         public void Init()
46         {
47             var rootCert = File.ReadAllText(TestCredentials.ClientCertAuthorityPath);
48             var keyCertPair = new KeyCertificatePair(
49                 File.ReadAllText(TestCredentials.ServerCertChainPath),
50                 File.ReadAllText(TestCredentials.ServerPrivateKeyPath));
51 
52             var serverCredentials = new SslServerCredentials(new[] { keyCertPair }, rootCert, true);
53             var clientCredentials = new SslCredentials(rootCert, keyCertPair);
54 
55             // Disable SO_REUSEPORT to prevent https://github.com/grpc/grpc/issues/10755
56             server = new Server(new[] { new ChannelOption(ChannelOptions.SoReuseport, 0) })
57             {
58                 Services = { TestService.BindService(new SslCredentialsTestServiceImpl()) },
59                 Ports = { { Host, ServerPort.PickUnused, serverCredentials } }
60             };
61             server.Start();
62 
63             var options = new List<ChannelOption>
64             {
65                 new ChannelOption(ChannelOptions.SslTargetNameOverride, TestCredentials.DefaultHostOverride)
66             };
67 
68             channel = new Channel(Host, server.Ports.Single().BoundPort, clientCredentials, options);
69             client = new TestService.TestServiceClient(channel);
70         }
71 
72         [OneTimeTearDown]
Cleanup()73         public void Cleanup()
74         {
75             channel.ShutdownAsync().Wait();
76             server.ShutdownAsync().Wait();
77         }
78 
79         [Test]
AuthenticatedClientAndServer()80         public void AuthenticatedClientAndServer()
81         {
82             var response = client.UnaryCall(new SimpleRequest { ResponseSize = 10 });
83             Assert.AreEqual(10, response.Payload.Body.Length);
84         }
85 
86         [Test]
AuthContextIsPopulated()87         public async Task AuthContextIsPopulated()
88         {
89             var call = client.StreamingInputCall();
90             await call.RequestStream.CompleteAsync();
91             var response = await call.ResponseAsync;
92             Assert.AreEqual(12345, response.AggregatedPayloadSize);
93         }
94 
95         private class SslCredentialsTestServiceImpl : TestService.TestServiceBase
96         {
UnaryCall(SimpleRequest request, ServerCallContext context)97             public override Task<SimpleResponse> UnaryCall(SimpleRequest request, ServerCallContext context)
98             {
99                 return Task.FromResult(new SimpleResponse { Payload = CreateZerosPayload(request.ResponseSize) });
100             }
101 
StreamingInputCall(IAsyncStreamReader<StreamingInputCallRequest> requestStream, ServerCallContext context)102             public override async Task<StreamingInputCallResponse> StreamingInputCall(IAsyncStreamReader<StreamingInputCallRequest> requestStream, ServerCallContext context)
103             {
104                 var authContext = context.AuthContext;
105                 await requestStream.ForEachAsync(request => TaskUtils.CompletedTask);
106 
107                 Assert.IsTrue(authContext.IsPeerAuthenticated);
108                 Assert.AreEqual("x509_subject_alternative_name", authContext.PeerIdentityPropertyName);
109                 Assert.IsTrue(authContext.PeerIdentity.Count() > 0);
110                 Assert.AreEqual("ssl", authContext.FindPropertiesByName("transport_security_type").First().Value);
111 
112                 return new StreamingInputCallResponse { AggregatedPayloadSize = 12345 };
113             }
114 
CreateZerosPayload(int size)115             private static Payload CreateZerosPayload(int size)
116             {
117                 return new Payload { Body = ByteString.CopyFrom(new byte[size]) };
118             }
119         }
120     }
121 }
122