// Program setid demonstrates how to use the cap and/or psx packages to
// change the uid and gids of a program.
3//
// A detailed write-up explaining the various ways to use it
// is available at:
6//
7//   https://sites.google.com/site/fullycapable/Home/using-go-to-set-uid-and-gids
8package main
9
10import (
11	"flag"
12	"fmt"
13	"io/ioutil"
14	"log"
15	"strconv"
16	"strings"
17	"syscall"
18	"unsafe"
19
20	"kernel.org/pub/linux/libs/security/libcap/cap"
21	"kernel.org/pub/linux/libs/security/libcap/psx"
22)
23
// Command-line flags. A value of -1 for -uid or -gid means "keep the
// current ID": main substitutes the caller's real uid (getuid) and
// effective gid (getegid) respectively. With -caps=false the IDs are
// changed with raw psx system calls instead of the cap package helpers.
var (
	uid      = flag.Int("uid", -1, "specify a uid with a value other than (uid)")
	gid      = flag.Int("gid", -1, "specify a gid with a value other than (egid)")
	drop     = flag.Bool("drop", true, "drop privilege once IDs have been changed")
	suppl    = flag.String("suppl", "", "comma separated list of groups")
	withCaps = flag.Bool("caps", true, "raise capabilities to setuid/setgid")
)
31
32// setIDWithCaps uses the cap.SetUID and cap.SetGroups functions.
33func setIDsWithCaps(setUID, setGID int, gids []int) {
34	if err := cap.SetGroups(setGID, gids...); err != nil {
35		log.Fatalf("group setting failed: %v", err)
36	}
37	if err := cap.SetUID(setUID); err != nil {
38		log.Fatalf("user setting failed: %v", err)
39	}
40}
41
42func main() {
43	flag.Parse()
44
45	showIDs("before", false, syscall.Getuid(), syscall.Getgid())
46
47	gids := splitToInts()
48	setGID := *gid
49	if *gid == -1 {
50		setGID = syscall.Getegid()
51	}
52	setUID := *uid
53	if *uid == -1 {
54		setUID = syscall.Getuid()
55	}
56
57	if *withCaps {
58		setIDsWithCaps(setUID, setGID, gids)
59	} else {
60		if _, _, err := psx.Syscall3(syscall.SYS_SETGID, uintptr(setGID), 0, 0); err != 0 {
61			log.Fatalf("failed to setgid(%d): %v", setGID, err)
62		}
63		if len(gids) != 0 {
64			gids32 := []int32{int32(setGID)}
65			for _, g := range gids {
66				gids32 = append(gids32, int32(g))
67			}
68			if _, _, err := psx.Syscall3(syscall.SYS_SETGROUPS, uintptr(unsafe.Pointer(&gids32[0])), 0, 0); err != 0 {
69				log.Fatalf("failed to setgroups(%d, %v): %v", setGID, gids32, err)
70			}
71		}
72		if _, _, err := psx.Syscall3(syscall.SYS_SETUID, uintptr(setUID), 0, 0); err != 0 {
73			log.Fatalf("failed to setgid(%d): %v", setUID, err)
74		}
75	}
76
77	if *drop {
78		if err := cap.NewSet().SetProc(); err != nil {
79			log.Fatalf("unable to drop privilege: %v", err)
80		}
81	}
82
83	showIDs("after", true, setUID, setGID)
84}
85
86// splitToInts parses a comma separated string to a slice of integers.
87func splitToInts() (ret []int) {
88	if *suppl == "" {
89		return
90	}
91	a := strings.Split(*suppl, ",")
92	for _, s := range a {
93		n, err := strconv.Atoi(s)
94		if err != nil {
95			log.Fatalf("bad supplementary group [%q]: %v", s, err)
96		}
97		ret = append(ret, n)
98	}
99	return
100}
101
// dumpStatus prints testCase and then, for every thread listed under
// /proc/<pid>/task, the first line of /proc/<tid>/status that starts
// with filter. When validate is true, a printed line that differs from
// expect is suffixed with " (bad)". Failing to list the task directory
// is fatal.
//
// It returns true if any thread's status file could not be read, or if
// validate is set and any printed line differed from expect. Threads
// whose status has no line matching filter are silently skipped.
func dumpStatus(testCase string, validate bool, filter, expect string) bool {
	fmt.Printf("%s:\n", testCase)
	taskDir := fmt.Sprintf("/proc/%d/task", syscall.Getpid())
	tasks, err := ioutil.ReadDir(taskDir)
	if err != nil {
		log.Fatal(err)
	}
	bad := false
	for _, task := range tasks {
		statusPath := fmt.Sprintf("/proc/%s/status", task.Name())
		content, err := ioutil.ReadFile(statusPath)
		if err != nil {
			fmt.Println(statusPath, err)
			bad = true
			continue
		}
		for _, line := range strings.Split(string(content), "\n") {
			if !strings.HasPrefix(line, filter) {
				continue
			}
			suffix := ""
			if validate && line != expect {
				bad = true
				suffix = " (bad)"
			}
			fmt.Printf("%s %s%s\n", statusPath, line, suffix)
			break
		}
	}
	return bad
}
136
137// showIDs dumps the thread map out of the /proc/<proc>/tasks
138// filesystem to confirm that all of the threads associated with the
139// process have the same uid/gid values. Note, the code does not
140// attempt to validate the supplementary groups at present.
141func showIDs(test string, validate bool, wantUID, wantGID int) {
142	fmt.Printf("%s capability state: %q\n", test, cap.GetProc())
143
144	failed := dumpStatus(test+" gid", validate, "Gid:", fmt.Sprintf("Gid:\t%d\t%d\t%d\t%d", wantGID, wantGID, wantGID, wantGID))
145
146	failed = dumpStatus(test+" uid", validate, "Uid:", fmt.Sprintf("Uid:\t%d\t%d\t%d\t%d", wantUID, wantUID, wantUID, wantUID)) || failed
147
148	if validate && failed {
149		log.Fatal("did not observe desired *id state")
150	}
151}
152