1// Copyright 2017 syzkaller project authors. All rights reserved.
2// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
3
4package rpctype
5
6import (
7	"compress/flate"
8	"fmt"
9	"io"
10	"net"
11	"net/rpc"
12	"os"
13	"time"
14
15	"github.com/google/syzkaller/pkg/log"
16)
17
// RPCServer bundles a network listener with the net/rpc server that handles
// the connections accepted from it.
type RPCServer struct {
	ln net.Listener // listener that Serve accepts connections from
	s  *rpc.Server  // rpc server that serves every accepted connection
}
22
23func NewRPCServer(addr string, receiver interface{}) (*RPCServer, error) {
24	ln, err := net.Listen("tcp", addr)
25	if err != nil {
26		return nil, fmt.Errorf("failed to listen on %v: %v", addr, err)
27	}
28	s := rpc.NewServer()
29	if err := s.Register(receiver); err != nil {
30		return nil, err
31	}
32	serv := &RPCServer{
33		ln: ln,
34		s:  s,
35	}
36	return serv, nil
37}
38
39func (serv *RPCServer) Serve() {
40	for {
41		conn, err := serv.ln.Accept()
42		if err != nil {
43			log.Logf(0, "failed to accept an rpc connection: %v", err)
44			continue
45		}
46		setupKeepAlive(conn, 10*time.Second)
47		go serv.s.ServeConn(newFlateConn(conn))
48	}
49}
50
51func (serv *RPCServer) Addr() net.Addr {
52	return serv.ln.Addr()
53}
54
// RPCClient pairs an rpc client with the raw connection it runs over.
type RPCClient struct {
	conn net.Conn    // raw connection; Call sets per-call deadlines on it
	c    *rpc.Client // rpc client speaking over the flate-wrapped conn
}
59
60func Dial(addr string) (net.Conn, error) {
61	var conn net.Conn
62	var err error
63	if addr == "stdin" {
64		// This is used by vm/gvisor which passes us a unix socket connection in stdin.
65		return net.FileConn(os.Stdin)
66	}
67	if conn, err = net.DialTimeout("tcp", addr, 60*time.Second); err != nil {
68		return nil, err
69	}
70	setupKeepAlive(conn, time.Minute)
71	return conn, nil
72}
73
74func NewRPCClient(addr string) (*RPCClient, error) {
75	conn, err := Dial(addr)
76	if err != nil {
77		return nil, err
78	}
79	cli := &RPCClient{
80		conn: conn,
81		c:    rpc.NewClient(newFlateConn(conn)),
82	}
83	return cli, nil
84}
85
86func (cli *RPCClient) Call(method string, args, reply interface{}) error {
87	// Note: SetDeadline is not implemented on fuchsia, so don't fail on error.
88	cli.conn.SetDeadline(time.Now().Add(5 * 60 * time.Second))
89	defer cli.conn.SetDeadline(time.Time{})
90	return cli.c.Call(method, args, reply)
91}
92
93func (cli *RPCClient) Close() {
94	cli.c.Close()
95}
96
97func RPCCall(addr, method string, args, reply interface{}) error {
98	c, err := NewRPCClient(addr)
99	if err != nil {
100		return err
101	}
102	defer c.Close()
103	return c.Call(method, args, reply)
104}
105
// setupKeepAlive enables TCP keep-alive on conn with the given period.
//
// Keep-alive is a TCP-only feature, so connections of any other type are
// left untouched rather than triggering a type-assertion panic. Errors from
// the socket option calls are ignored on purpose: keep-alive is a best-effort
// tweak and must not prevent the connection from being used.
func setupKeepAlive(conn net.Conn, keepAlive time.Duration) {
	tcpConn, ok := conn.(*net.TCPConn)
	if !ok {
		return
	}
	tcpConn.SetKeepAlive(true)
	tcpConn.SetKeepAlivePeriod(keepAlive)
}
110
// flateConn wraps net.Conn in flate.Reader/Writer for compressed traffic.
// Reads go through the decompressing reader, writes go through the
// compressing writer, and Close closes both of them as well as the
// underlying connection.
type flateConn struct {
	r io.ReadCloser // decompressing reader over the underlying connection
	w *flate.Writer // compressing writer over the underlying connection
	c io.Closer     // the underlying connection itself, kept for Close
}
117
118func newFlateConn(conn io.ReadWriteCloser) io.ReadWriteCloser {
119	w, err := flate.NewWriter(conn, 9)
120	if err != nil {
121		panic(err)
122	}
123	return &flateConn{
124		r: flate.NewReader(conn),
125		w: w,
126		c: conn,
127	}
128}
129
130func (fc *flateConn) Read(data []byte) (int, error) {
131	return fc.r.Read(data)
132}
133
134func (fc *flateConn) Write(data []byte) (int, error) {
135	n, err := fc.w.Write(data)
136	if err != nil {
137		return n, err
138	}
139	if err := fc.w.Flush(); err != nil {
140		return n, err
141	}
142	return n, nil
143}
144
145func (fc *flateConn) Close() error {
146	var err0 error
147	if err := fc.r.Close(); err != nil {
148		err0 = err
149	}
150	if err := fc.w.Close(); err != nil {
151		err0 = err
152	}
153	if err := fc.c.Close(); err != nil {
154		err0 = err
155	}
156	return err0
157}
158