// Copyright 2017 syzkaller project authors. All rights reserved. // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. package rpctype import ( "compress/flate" "fmt" "io" "net" "net/rpc" "os" "time" "github.com/google/syzkaller/pkg/log" ) type RPCServer struct { ln net.Listener s *rpc.Server } func NewRPCServer(addr string, receiver interface{}) (*RPCServer, error) { ln, err := net.Listen("tcp", addr) if err != nil { return nil, fmt.Errorf("failed to listen on %v: %v", addr, err) } s := rpc.NewServer() if err := s.Register(receiver); err != nil { return nil, err } serv := &RPCServer{ ln: ln, s: s, } return serv, nil } func (serv *RPCServer) Serve() { for { conn, err := serv.ln.Accept() if err != nil { log.Logf(0, "failed to accept an rpc connection: %v", err) continue } setupKeepAlive(conn, 10*time.Second) go serv.s.ServeConn(newFlateConn(conn)) } } func (serv *RPCServer) Addr() net.Addr { return serv.ln.Addr() } type RPCClient struct { conn net.Conn c *rpc.Client } func Dial(addr string) (net.Conn, error) { var conn net.Conn var err error if addr == "stdin" { // This is used by vm/gvisor which passes us a unix socket connection in stdin. return net.FileConn(os.Stdin) } if conn, err = net.DialTimeout("tcp", addr, 60*time.Second); err != nil { return nil, err } setupKeepAlive(conn, time.Minute) return conn, nil } func NewRPCClient(addr string) (*RPCClient, error) { conn, err := Dial(addr) if err != nil { return nil, err } cli := &RPCClient{ conn: conn, c: rpc.NewClient(newFlateConn(conn)), } return cli, nil } func (cli *RPCClient) Call(method string, args, reply interface{}) error { // Note: SetDeadline is not implemented on fuchsia, so don't fail on error. cli.conn.SetDeadline(time.Now().Add(5 * 60 * time.Second)) defer cli.conn.SetDeadline(time.Time{}) return cli.c.Call(method, args, reply) } func (cli *RPCClient) Close() { cli.c.Close() } func RPCCall(addr, method string, args, reply interface{}) error { c, err := NewRPCClient(addr) if err != nil { return err } defer c.Close() return c.Call(method, args, reply) } func setupKeepAlive(conn net.Conn, keepAlive time.Duration) { conn.(*net.TCPConn).SetKeepAlive(true) conn.(*net.TCPConn).SetKeepAlivePeriod(keepAlive) } // flateConn wraps net.Conn in flate.Reader/Writer for compressed traffic. type flateConn struct { r io.ReadCloser w *flate.Writer c io.Closer } func newFlateConn(conn io.ReadWriteCloser) io.ReadWriteCloser { w, err := flate.NewWriter(conn, 9) if err != nil { panic(err) } return &flateConn{ r: flate.NewReader(conn), w: w, c: conn, } } func (fc *flateConn) Read(data []byte) (int, error) { return fc.r.Read(data) } func (fc *flateConn) Write(data []byte) (int, error) { n, err := fc.w.Write(data) if err != nil { return n, err } if err := fc.w.Flush(); err != nil { return n, err } return n, nil } func (fc *flateConn) Close() error { var err0 error if err := fc.r.Close(); err != nil { err0 = err } if err := fc.w.Close(); err != nil { err0 = err } if err := fc.c.Close(); err != nil { err0 = err } return err0 }