1// Copyright 2017 syzkaller project authors. All rights reserved. 2// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. 3 4package rpctype 5 6import ( 7 "compress/flate" 8 "fmt" 9 "io" 10 "net" 11 "net/rpc" 12 "os" 13 "time" 14 15 "github.com/google/syzkaller/pkg/log" 16) 17 18type RPCServer struct { 19 ln net.Listener 20 s *rpc.Server 21} 22 23func NewRPCServer(addr string, receiver interface{}) (*RPCServer, error) { 24 ln, err := net.Listen("tcp", addr) 25 if err != nil { 26 return nil, fmt.Errorf("failed to listen on %v: %v", addr, err) 27 } 28 s := rpc.NewServer() 29 if err := s.Register(receiver); err != nil { 30 return nil, err 31 } 32 serv := &RPCServer{ 33 ln: ln, 34 s: s, 35 } 36 return serv, nil 37} 38 39func (serv *RPCServer) Serve() { 40 for { 41 conn, err := serv.ln.Accept() 42 if err != nil { 43 log.Logf(0, "failed to accept an rpc connection: %v", err) 44 continue 45 } 46 setupKeepAlive(conn, 10*time.Second) 47 go serv.s.ServeConn(newFlateConn(conn)) 48 } 49} 50 51func (serv *RPCServer) Addr() net.Addr { 52 return serv.ln.Addr() 53} 54 55type RPCClient struct { 56 conn net.Conn 57 c *rpc.Client 58} 59 60func Dial(addr string) (net.Conn, error) { 61 var conn net.Conn 62 var err error 63 if addr == "stdin" { 64 // This is used by vm/gvisor which passes us a unix socket connection in stdin. 65 return net.FileConn(os.Stdin) 66 } 67 if conn, err = net.DialTimeout("tcp", addr, 60*time.Second); err != nil { 68 return nil, err 69 } 70 setupKeepAlive(conn, time.Minute) 71 return conn, nil 72} 73 74func NewRPCClient(addr string) (*RPCClient, error) { 75 conn, err := Dial(addr) 76 if err != nil { 77 return nil, err 78 } 79 cli := &RPCClient{ 80 conn: conn, 81 c: rpc.NewClient(newFlateConn(conn)), 82 } 83 return cli, nil 84} 85 86func (cli *RPCClient) Call(method string, args, reply interface{}) error { 87 // Note: SetDeadline is not implemented on fuchsia, so don't fail on error. 88 cli.conn.SetDeadline(time.Now().Add(5 * 60 * time.Second)) 89 defer cli.conn.SetDeadline(time.Time{}) 90 return cli.c.Call(method, args, reply) 91} 92 93func (cli *RPCClient) Close() { 94 cli.c.Close() 95} 96 97func RPCCall(addr, method string, args, reply interface{}) error { 98 c, err := NewRPCClient(addr) 99 if err != nil { 100 return err 101 } 102 defer c.Close() 103 return c.Call(method, args, reply) 104} 105 106func setupKeepAlive(conn net.Conn, keepAlive time.Duration) { 107 conn.(*net.TCPConn).SetKeepAlive(true) 108 conn.(*net.TCPConn).SetKeepAlivePeriod(keepAlive) 109} 110 111// flateConn wraps net.Conn in flate.Reader/Writer for compressed traffic. 112type flateConn struct { 113 r io.ReadCloser 114 w *flate.Writer 115 c io.Closer 116} 117 118func newFlateConn(conn io.ReadWriteCloser) io.ReadWriteCloser { 119 w, err := flate.NewWriter(conn, 9) 120 if err != nil { 121 panic(err) 122 } 123 return &flateConn{ 124 r: flate.NewReader(conn), 125 w: w, 126 c: conn, 127 } 128} 129 130func (fc *flateConn) Read(data []byte) (int, error) { 131 return fc.r.Read(data) 132} 133 134func (fc *flateConn) Write(data []byte) (int, error) { 135 n, err := fc.w.Write(data) 136 if err != nil { 137 return n, err 138 } 139 if err := fc.w.Flush(); err != nil { 140 return n, err 141 } 142 return n, nil 143} 144 145func (fc *flateConn) Close() error { 146 var err0 error 147 if err := fc.r.Close(); err != nil { 148 err0 = err 149 } 150 if err := fc.w.Close(); err != nil { 151 err0 = err 152 } 153 if err := fc.c.Close(); err != nil { 154 err0 = err 155 } 156 return err0 157} 158