1 /*
2  * Copyright (c) 1996, 2012, Oracle and/or its affiliates. All rights reserved.
3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4  *
5  * This code is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License version 2 only, as
7  * published by the Free Software Foundation.  Oracle designates this
8  * particular file as subject to the "Classpath" exception as provided
9  * by Oracle in the LICENSE file that accompanied this code.
10  *
11  * This code is distributed in the hope that it will be useful, but WITHOUT
12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
14  * version 2 for more details (a copy is included in the LICENSE file that
15  * accompanied this code).
16  *
17  * You should have received a copy of the GNU General Public License version
18  * 2 along with this work; if not, write to the Free Software Foundation,
19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20  *
21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22  * or visit www.oracle.com if you need additional information or have any
23  * questions.
24  */
25 
26 
27 package sun.security.ssl;
28 
29 import java.io.OutputStream;
30 import java.io.IOException;
31 import java.security.MessageDigest;
32 
/**
 * Output stream for handshake data.  This is used only internally
 * to the SSL classes.
 *
 * MT note:  one thread at a time is presumed to be writing handshake
 * messages, but (after initial connection setup) it's possible to
 * have other threads reading/writing application data.  It's the
 * SSLSocketImpl class that synchronizes record writes.
 *
 * @author  David Brownell
 */
44 public class HandshakeOutStream extends OutputStream {
45 
46     private SSLSocketImpl socket;
47     private SSLEngineImpl engine;
48 
49     OutputRecord r;
50 
HandshakeOutStream(ProtocolVersion protocolVersion, ProtocolVersion helloVersion, HandshakeHash handshakeHash, SSLSocketImpl socket)51     HandshakeOutStream(ProtocolVersion protocolVersion,
52             ProtocolVersion helloVersion, HandshakeHash handshakeHash,
53             SSLSocketImpl socket) {
54         this.socket = socket;
55         r = new OutputRecord(Record.ct_handshake);
56         init(protocolVersion, helloVersion, handshakeHash);
57     }
58 
HandshakeOutStream(ProtocolVersion protocolVersion, ProtocolVersion helloVersion, HandshakeHash handshakeHash, SSLEngineImpl engine)59     HandshakeOutStream(ProtocolVersion protocolVersion,
60             ProtocolVersion helloVersion, HandshakeHash handshakeHash,
61             SSLEngineImpl engine) {
62         this.engine = engine;
63         r = new EngineOutputRecord(Record.ct_handshake, engine);
64         init(protocolVersion, helloVersion, handshakeHash);
65     }
66 
init(ProtocolVersion protocolVersion, ProtocolVersion helloVersion, HandshakeHash handshakeHash)67     private void init(ProtocolVersion protocolVersion,
68             ProtocolVersion helloVersion, HandshakeHash handshakeHash) {
69         r.setVersion(protocolVersion);
70         r.setHelloVersion(helloVersion);
71         r.setHandshakeHash(handshakeHash);
72     }
73 
74 
75     /*
76      * Update the handshake data hashes ... mostly for use after a
77      * client cert has been sent, so the cert verify message can be
78      * constructed correctly yet without forcing extra I/O.  In all
79      * other cases, automatic hash calculation suffices.
80      */
doHashes()81     void doHashes() {
82         r.doHashes();
83     }
84 
85     /*
86      * Write some data out onto the stream ... buffers as much as possible.
87      * Hashes are updated automatically if something gets flushed to the
88      * network (e.g. a big cert message etc).
89      */
write(byte buf[], int off, int len)90     public void write(byte buf[], int off, int len) throws IOException {
91         while (len > 0) {
92             int howmuch = Math.min(len, r.availableDataBytes());
93 
94             if (howmuch == 0) {
95                 flush();
96             } else {
97                 r.write(buf, off, howmuch);
98                 off += howmuch;
99                 len -= howmuch;
100             }
101         }
102     }
103 
104     /*
105      * write-a-byte
106      */
write(int i)107     public void write(int i) throws IOException {
108         if (r.availableDataBytes() < 1) {
109             flush();
110         }
111         r.write(i);
112     }
113 
flush()114     public void flush() throws IOException {
115         if (socket != null) {
116             try {
117                 socket.writeRecord(r);
118             } catch (IOException e) {
119                 // Had problems writing; check if there was an
120                 // alert from peer. If alert received, waitForClose
121                 // will throw an exception for the alert
122                 socket.waitForClose(true);
123 
124                 // No alert was received, just rethrow exception
125                 throw e;
126             }
127         } else {  // engine != null
128             /*
129              * Even if record might be empty, flush anyway in case
130              * there is a finished handshake message that we need
131              * to queue.
132              */
133             engine.writeRecord((EngineOutputRecord)r);
134         }
135     }
136 
137     /*
138      * Tell the OutputRecord that a finished message was
139      * contained either in this record or the one immeiately
140      * preceeding it.  We need to reliably pass back notifications
141      * that a finish message occured.
142      */
setFinishedMsg()143     void setFinishedMsg() {
144         assert(socket == null);
145 
146         ((EngineOutputRecord)r).setFinishedMsg();
147     }
148 
149     /*
150      * Put integers encoded in standard 8, 16, 24, and 32 bit
151      * big endian formats. Note that OutputStream.write(int) only
152      * writes the least significant 8 bits and ignores the rest.
153      */
154 
putInt8(int i)155     void putInt8(int i) throws IOException {
156         checkOverflow(i, Record.OVERFLOW_OF_INT08);
157         r.write(i);
158     }
159 
putInt16(int i)160     void putInt16(int i) throws IOException {
161         checkOverflow(i, Record.OVERFLOW_OF_INT16);
162         if (r.availableDataBytes() < 2) {
163             flush();
164         }
165         r.write(i >> 8);
166         r.write(i);
167     }
168 
putInt24(int i)169     void putInt24(int i) throws IOException {
170         checkOverflow(i, Record.OVERFLOW_OF_INT24);
171         if (r.availableDataBytes() < 3) {
172             flush();
173         }
174         r.write(i >> 16);
175         r.write(i >> 8);
176         r.write(i);
177     }
178 
putInt32(int i)179     void putInt32(int i) throws IOException {
180         if (r.availableDataBytes() < 4) {
181             flush();
182         }
183         r.write(i >> 24);
184         r.write(i >> 16);
185         r.write(i >> 8);
186         r.write(i);
187     }
188 
189     /*
190      * Put byte arrays with length encoded as 8, 16, 24 bit
191      * integers in big-endian format.
192      */
putBytes8(byte b[])193     void putBytes8(byte b[]) throws IOException {
194         if (b == null) {
195             putInt8(0);
196             return;
197         } else {
198             checkOverflow(b.length, Record.OVERFLOW_OF_INT08);
199         }
200         putInt8(b.length);
201         write(b, 0, b.length);
202     }
203 
putBytes16(byte b[])204     public void putBytes16(byte b[]) throws IOException {
205         if (b == null) {
206             putInt16(0);
207             return;
208         } else {
209             checkOverflow(b.length, Record.OVERFLOW_OF_INT16);
210         }
211         putInt16(b.length);
212         write(b, 0, b.length);
213     }
214 
putBytes24(byte b[])215     void putBytes24(byte b[]) throws IOException {
216         if (b == null) {
217             putInt24(0);
218             return;
219         } else {
220             checkOverflow(b.length, Record.OVERFLOW_OF_INT24);
221         }
222         putInt24(b.length);
223         write(b, 0, b.length);
224     }
225 
checkOverflow(int length, int overflow)226     private void checkOverflow(int length, int overflow) {
227         if (length >= overflow) {
228             // internal_error alert will be triggered
229             throw new RuntimeException(
230                     "Field length overflow, the field length (" +
231                     length + ") should be less than " + overflow);
232         }
233     }
234 }
235