1// Copyright 2020 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package workspace
16
17import (
18	"fmt"
19	"io"
20	"io/ioutil"
21	"os"
22	"path/filepath"
23	"strings"
24)
25
// FileCopier copies the parts of a codebase that are not git projects into a
// workspace directory. It holds no state, so its zero value is ready to use.
type FileCopier struct{}
28
29func NewFileCopier() *FileCopier {
30	var f FileCopier
31	return &f
32}
33
34func (f FileCopier) GetIsGitProjectFunc(codebaseDir string, gitProjects []string) func(string) (bool, error) {
35	//Convert the git project list to a set to speed up lookups
36	gitProjectSet := make(map[string]struct{})
37	var exists = struct{}{}
38	for _, project := range gitProjects {
39		gitProjectSet[project] = exists
40	}
41
42	return func(pathToCheck string) (bool, error) {
43		var err error
44		if pathToCheck, err = filepath.Rel(codebaseDir, pathToCheck); err != nil {
45			return false, err
46		}
47		if _, ok := gitProjectSet[pathToCheck]; ok {
48			return true, err
49		}
50		return false, err
51	}
52}
53
54func (f FileCopier) GetContainsGitProjectFunc(codebaseDir string, gitProjects []string) func(string) (bool, error) {
55	//Extract the set of dirs that contain git projects
56	containsGitSet := make(map[string]struct{})
57	var exists = struct{}{}
58	for _, project := range gitProjects {
59		for dir := project; dir != "." && dir != "/"; dir = filepath.Dir(dir) {
60			containsGitSet[dir] = exists
61		}
62	}
63
64	return func(pathToCheck string) (bool, error) {
65		var err error
66		if pathToCheck, err = filepath.Rel(codebaseDir, pathToCheck); err != nil {
67			return false, err
68		}
69		if _, ok := containsGitSet[pathToCheck]; ok {
70			return true, err
71		}
72		return false, err
73	}
74}
75
76//gitProjects is relative to codebaseDir
77func (f FileCopier) Copy(codebaseDir string, gitProjects []string, workspaceDir string) error {
78	isGitProject := f.GetIsGitProjectFunc(codebaseDir, gitProjects)
79	containsGitProject := f.GetContainsGitProjectFunc(codebaseDir, gitProjects)
80
81	return filepath.Walk(codebaseDir,
82		func(path string, info os.FileInfo, err error) error {
83			if err != nil {
84				return err
85			}
86
87			// Copy files
88			if !info.IsDir() {
89				return f.CopyNode(info, codebaseDir, path, workspaceDir)
90			}
91
92			if path == filepath.Clean(codebaseDir) {
93				return nil
94			}
95
96			// Always skip traversal of root repo directories
97			if path == filepath.Join(codebaseDir, ".repo") {
98				return filepath.SkipDir
99			}
100
101			// Skip all git projects
102			var isGitProj bool
103			if isGitProj, err = isGitProject(path); err != nil {
104				return err
105			}
106			if isGitProj {
107				return filepath.SkipDir
108			}
109
110			// Copy over files
111			var containsGitProj bool
112			if containsGitProj, err = containsGitProject(path); err != nil {
113				return err
114			}
115			if !containsGitProj {
116				destPath, err := f.GetDestPath(codebaseDir, path, workspaceDir)
117				if err != nil {
118					return err
119				}
120				if err = f.CopyDirRecursive(info, path, destPath); err != nil {
121					return err
122				}
123				return filepath.SkipDir
124			}
125			return f.CopyNode(info, codebaseDir, path, workspaceDir)
126		})
127}
128
129func (f FileCopier) GetDestPath(codebaseDir, sourcePath, workspaceDir string) (string, error) {
130	if !strings.HasPrefix(sourcePath+"/", codebaseDir+"/") {
131		return "", fmt.Errorf("%s is not contained in %s", sourcePath, codebaseDir)
132	}
133	relPath, err := filepath.Rel(codebaseDir, sourcePath)
134	if err != nil {
135		return "", err
136	}
137	destPath := filepath.Join(workspaceDir, relPath)
138	return destPath, err
139}
140
141// Copy any single file, symlink or dir non-recursively
142// sourcePath must be contained in codebaseDir
143func (f FileCopier) CopyNode(sourceInfo os.FileInfo, codebaseDir, sourcePath, workspaceDir string) error {
144	destPath, err := f.GetDestPath(codebaseDir, sourcePath, workspaceDir)
145	if err != nil {
146		return err
147	}
148	switch {
149	case sourceInfo.Mode()&os.ModeSymlink == os.ModeSymlink:
150		return f.CopySymlink(sourcePath, destPath)
151	case sourceInfo.Mode().IsDir():
152		return f.CopyDirOnly(sourceInfo, destPath)
153	default:
154		return f.CopyFile(sourceInfo, sourcePath, destPath)
155	}
156}
157
158func (f FileCopier) CopySymlink(sourcePath string, destPath string) error {
159	// Skip symlink if it already exists at the destination
160	_, err := os.Lstat(destPath)
161	if err == nil {
162		return nil
163	}
164
165	target, err := os.Readlink(sourcePath)
166	if err != nil {
167		return err
168	}
169
170	return os.Symlink(target, destPath)
171}
172
173// CopyDirOnly copies a directory non-recursively
174// sourcePath must be contained in codebaseDir
175func (f FileCopier) CopyDirOnly(sourceInfo os.FileInfo, destPath string) error {
176	_, err := os.Stat(destPath)
177	if err == nil {
178		// Dir already exists, nothing to do
179		return err
180	} else if os.IsNotExist(err) {
181		return os.Mkdir(destPath, sourceInfo.Mode())
182	}
183	return err
184}
185
186// CopyFile copies a single file
187// sourcePath must be contained in codebaseDir
188func (f FileCopier) CopyFile(sourceInfo os.FileInfo, sourcePath, destPath string) error {
189	//Skip file if it already exists at the destination
190	_, err := os.Lstat(destPath)
191	if err == nil {
192		return nil
193	}
194
195	sourceFile, err := os.Open(sourcePath)
196	if err != nil {
197		return err
198	}
199	defer sourceFile.Close()
200
201	destFile, err := os.Create(destPath)
202	if err != nil {
203		return err
204	}
205	defer destFile.Close()
206
207	_, err = io.Copy(destFile, sourceFile)
208	if err != nil {
209		return err
210	}
211	return os.Chmod(destPath, sourceInfo.Mode())
212}
213
214func (f FileCopier) CopyDirRecursive(sourceInfo os.FileInfo, sourcePath, destPath string) error {
215	if err := f.CopyDirOnly(sourceInfo, destPath); err != nil {
216		return err
217	}
218	childNodes, err := ioutil.ReadDir(sourcePath)
219	if err != nil {
220		return err
221	}
222	for _, childInfo := range childNodes {
223		childSourcePath := filepath.Join(sourcePath, childInfo.Name())
224		childDestPath := filepath.Join(destPath, childInfo.Name())
225		switch {
226		case childInfo.Mode()&os.ModeSymlink == os.ModeSymlink:
227			if err = f.CopySymlink(childSourcePath, childDestPath); err != nil {
228				return err
229			}
230		case childInfo.Mode().IsDir():
231			if err = f.CopyDirRecursive(childInfo, childSourcePath, childDestPath); err != nil {
232				return err
233			}
234		default:
235			if err = f.CopyFile(childInfo, childSourcePath, childDestPath); err != nil {
236				return err
237			}
238		}
239	}
240	return err
241}
242