1// Copyright 2015 Google Inc. All rights reserved
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package kati
16
17import (
18	"bytes"
19	"errors"
20	"fmt"
21	"io"
22	"regexp"
23	"strconv"
24	"strings"
25
26	"github.com/golang/glog"
27)
28
// Sentinel errors used by the expression parser in this file.
var (
	// errNotLiteral is returned by valueNum when the value is neither a
	// literal nor a tmpval.
	errNotLiteral = errors.New("valueNum: not literal")

	// errEndOfInput is returned by parseExpr when a terminator was expected
	// but the input ended first.
	errEndOfInput = errors.New("unexpected end of input")

	// errUnterminatedVariableReference is returned by parseDollar when a
	// bracketed reference is never closed.
	errUnterminatedVariableReference = errors.New("*** unterminated variable reference.")
)
35
// evalWriter is the output sink that Value.Eval writes to.
// In addition to plain io.Writer writes it offers word-oriented writes
// and resetSep, which expr.Eval calls before evaluating each element.
// NOTE(review): the exact word/separator semantics are defined by the
// implementations, which are not part of this file.
type evalWriter interface {
	io.Writer
	writeWord([]byte)
	writeWordString(string)
	resetSep()
}
42
// Value is the interface implemented by every node of a parsed
// expression (literal text, variable references, function calls, ...).
type Value interface {
	// String returns the textual form of the value.
	String() string
	// Eval evaluates the value using ev and writes the result to w.
	Eval(w evalWriter, ev *Evaluator) error
	// serialize converts the value to a serializableVar.
	serialize() serializableVar
	// dump writes the encoded form of the value to d.
	dump(d *dumpbuf)
}
50
51// literal is literal value.
52type literal string
53
54func (s literal) String() string { return string(s) }
55func (s literal) Eval(w evalWriter, ev *Evaluator) error {
56	io.WriteString(w, string(s))
57	return nil
58}
59func (s literal) serialize() serializableVar {
60	return serializableVar{Type: "literal", V: string(s)}
61}
62func (s literal) dump(d *dumpbuf) {
63	d.Byte(valueTypeLiteral)
64	d.Bytes([]byte(s))
65}
66
67// tmpval is temporary value.
68type tmpval []byte
69
70func (t tmpval) String() string { return string(t) }
71func (t tmpval) Eval(w evalWriter, ev *Evaluator) error {
72	w.Write(t)
73	return nil
74}
75func (t tmpval) Value() []byte { return []byte(t) }
76func (t tmpval) serialize() serializableVar {
77	return serializableVar{Type: "tmpval", V: string(t)}
78}
79func (t tmpval) dump(d *dumpbuf) {
80	d.Byte(valueTypeTmpval)
81	d.Bytes(t)
82}
83
84// expr is a list of values.
85type expr []Value
86
87func (e expr) String() string {
88	var s []string
89	for _, v := range e {
90		s = append(s, v.String())
91	}
92	return strings.Join(s, "")
93}
94
95func (e expr) Eval(w evalWriter, ev *Evaluator) error {
96	for _, v := range e {
97		w.resetSep()
98		err := v.Eval(w, ev)
99		if err != nil {
100			return err
101		}
102	}
103	return nil
104}
105
106func (e expr) serialize() serializableVar {
107	r := serializableVar{Type: "expr"}
108	for _, v := range e {
109		r.Children = append(r.Children, v.serialize())
110	}
111	return r
112}
113func (e expr) dump(d *dumpbuf) {
114	d.Byte(valueTypeExpr)
115	d.Int(len(e))
116	for _, v := range e {
117		v.dump(d)
118	}
119}
120
121func compactExpr(e expr) Value {
122	if len(e) == 1 {
123		return e[0]
124	}
125	// TODO(ukai): concat literal
126	return e
127}
128func toExpr(v Value) expr {
129	if v == nil {
130		return nil
131	}
132	if e, ok := v.(expr); ok {
133		return e
134	}
135	return expr{v}
136}
137
138// varref is variable reference. e.g. ${foo}.
139type varref struct {
140	varname Value
141	paren   byte
142}
143
144func (v *varref) String() string {
145	varname := v.varname.String()
146	if len(varname) == 1 && v.paren == 0 {
147		return fmt.Sprintf("$%s", varname)
148	}
149	paren := v.paren
150	if paren == 0 {
151		paren = '{'
152	}
153	return fmt.Sprintf("$%c%s%c", paren, varname, closeParen(paren))
154}
155
156func (v *varref) Eval(w evalWriter, ev *Evaluator) error {
157	te := traceEvent.begin("var", v, traceEventMain)
158	buf := newEbuf()
159	err := v.varname.Eval(buf, ev)
160	if err != nil {
161		return err
162	}
163	vv := ev.LookupVar(buf.String())
164	buf.release()
165	err = vv.Eval(w, ev)
166	if err != nil {
167		return err
168	}
169	traceEvent.end(te)
170	return nil
171}
172
173func (v *varref) serialize() serializableVar {
174	return serializableVar{
175		Type:     "varref",
176		V:        string(v.paren),
177		Children: []serializableVar{v.varname.serialize()},
178	}
179}
180func (v *varref) dump(d *dumpbuf) {
181	d.Byte(valueTypeVarref)
182	d.Byte(v.paren)
183	v.varname.dump(d)
184}
185
186// paramref is parameter reference e.g. $1.
187type paramref int
188
189func (p paramref) String() string {
190	return fmt.Sprintf("$%d", int(p))
191}
192
193func (p paramref) Eval(w evalWriter, ev *Evaluator) error {
194	te := traceEvent.begin("param", p, traceEventMain)
195	n := int(p)
196	if n < len(ev.paramVars) {
197		err := ev.paramVars[n].Eval(w, ev)
198		if err != nil {
199			return err
200		}
201	} else {
202		vv := ev.LookupVar(fmt.Sprintf("%d", n))
203		err := vv.Eval(w, ev)
204		if err != nil {
205			return err
206		}
207	}
208	traceEvent.end(te)
209	return nil
210}
211
212func (p paramref) serialize() serializableVar {
213	return serializableVar{Type: "paramref", V: strconv.Itoa(int(p))}
214}
215
216func (p paramref) dump(d *dumpbuf) {
217	d.Byte(valueTypeParamref)
218	d.Int(int(p))
219}
220
221// varsubst is variable substitutaion. e.g. ${var:pat=subst}.
222type varsubst struct {
223	varname Value
224	pat     Value
225	subst   Value
226	paren   byte
227}
228
229func (v varsubst) String() string {
230	paren := v.paren
231	if paren == 0 {
232		paren = '{'
233	}
234	return fmt.Sprintf("$%c%s:%s=%s%c", paren, v.varname, v.pat, v.subst, closeParen(paren))
235}
236
237func (v varsubst) Eval(w evalWriter, ev *Evaluator) error {
238	te := traceEvent.begin("varsubst", v, traceEventMain)
239	buf := newEbuf()
240	params, err := ev.args(buf, v.varname, v.pat, v.subst)
241	if err != nil {
242		return err
243	}
244	vname := string(params[0])
245	pat := string(params[1])
246	subst := string(params[2])
247	buf.Reset()
248	vv := ev.LookupVar(vname)
249	err = vv.Eval(buf, ev)
250	if err != nil {
251		return err
252	}
253	vals := splitSpaces(buf.String())
254	buf.release()
255	space := false
256	for _, val := range vals {
257		if space {
258			io.WriteString(w, " ")
259		}
260		io.WriteString(w, substRef(pat, subst, val))
261		space = true
262	}
263	traceEvent.end(te)
264	return nil
265}
266
267func (v varsubst) serialize() serializableVar {
268	return serializableVar{
269		Type: "varsubst",
270		V:    string(v.paren),
271		Children: []serializableVar{
272			v.varname.serialize(),
273			v.pat.serialize(),
274			v.subst.serialize(),
275		},
276	}
277}
278
279func (v varsubst) dump(d *dumpbuf) {
280	d.Byte(valueTypeVarsubst)
281	d.Byte(v.paren)
282	v.varname.dump(d)
283	v.pat.dump(d)
284	v.subst.dump(d)
285}
286
287func str(buf []byte, alloc bool) Value {
288	if alloc {
289		return literal(string(buf))
290	}
291	return tmpval(buf)
292}
293
294func appendStr(exp expr, buf []byte, alloc bool) expr {
295	if len(buf) == 0 {
296		return exp
297	}
298	if len(exp) == 0 {
299		return append(exp, str(buf, alloc))
300	}
301	switch v := exp[len(exp)-1].(type) {
302	case literal:
303		v += literal(string(buf))
304		exp[len(exp)-1] = v
305		return exp
306	case tmpval:
307		v = append(v, buf...)
308		exp[len(exp)-1] = v
309		return exp
310	}
311	return append(exp, str(buf, alloc))
312}
313
314func valueNum(v Value) (int, error) {
315	switch v := v.(type) {
316	case literal, tmpval:
317		n, err := strconv.ParseInt(v.String(), 10, 64)
318		return int(n), err
319	}
320	return 0, errNotLiteral
321}
322
// parseOp holds the options that control parseExpr.
type parseOp struct {
	// alloc indicates text will be allocated as literal (string)
	// instead of a tmpval that refers to the input (see str).
	alloc bool

	// matchParen matches parenthesis: while inside a nested bracket pair,
	// parseExpr does not treat the matching closing bracket as a terminator.
	// note: required for func arg
	matchParen bool
}
331
// parseExpr parses expression in `in` until it finds any byte in term.
// if term is nil, it will parse to end of input.
// if term is not nil, and it reaches to end of input, it returns
// errEndOfInput together with the (uncompacted) expr parsed so far.
// it returns parsed value, and parsed length `n`, so that in[:n] is the
// parsed text and, when a terminator stopped the parse, in[n] is that
// terminator byte.
//
// With op.matchParen set, term is temporarily modified in place while the
// parser is inside a nested bracket pair (see below), so callers must pass
// a term slice that may be written to.
func parseExpr(in, term []byte, op parseOp) (Value, int, error) {
	var exp expr
	b := 0 // start of the pending literal text in[b:i] not yet added to exp
	i := 0 // current scan position
	// saveParen is the closing bracket that is currently disabled in term
	// (replaced by a 0 byte); 0 when no bracket is disabled.
	var saveParen byte
	parenDepth := 0
Loop:
	for i < len(in) {
		ch := in[i]
		if term != nil && bytes.IndexByte(term, ch) >= 0 {
			break Loop
		}
		switch ch {
		case '$':
			if i+1 >= len(in) {
				// '$' is the last byte: stop here without consuming it.
				break Loop
			}
			if in[i+1] == '$' {
				// "$$": keep the text up to and including the first '$'
				// and skip the second one.
				exp = appendStr(exp, in[b:i+1], op.alloc)
				i += 2
				b = i
				continue
			}
			if bytes.IndexByte(term, in[i+1]) >= 0 {
				// '$' directly followed by a terminator: emit a reference
				// with an empty name and stop on the terminator.
				exp = appendStr(exp, in[b:i], op.alloc)
				exp = append(exp, &varref{varname: literal("")})
				i++
				b = i
				break Loop
			}
			// Flush pending text, then parse the reference itself.
			exp = appendStr(exp, in[b:i], op.alloc)
			v, n, err := parseDollar(in[i:], op.alloc)
			if err != nil {
				return nil, 0, err
			}
			i += n
			b = i
			exp = append(exp, v)
			continue
		case '(', '{':
			if !op.matchParen {
				// leaves the switch only; the byte is ordinary text.
				break
			}
			cp := closeParen(ch)
			if i := bytes.IndexByte(term, cp); i >= 0 {
				// The matching closing bracket is a terminator: disable it
				// in term until this bracket pair is closed.
				parenDepth++
				saveParen = cp
				term[i] = 0
			} else if cp == saveParen {
				// Nested bracket of the kind that is already disabled.
				parenDepth++
			}
		case saveParen:
			if !op.matchParen {
				break
			}
			parenDepth--
			if parenDepth == 0 {
				// Outermost nested pair closed: restore the terminator.
				i := bytes.IndexByte(term, 0)
				term[i] = saveParen
				saveParen = 0
			}
		}
		i++
	}
	exp = appendStr(exp, in[b:i], op.alloc)
	if i == len(in) && term != nil {
		glog.Warningf("parse: unexpected end of input: %q %d [%q]", in, i, term)
		return exp, i, errEndOfInput
	}
	return compactExpr(exp), i, nil
}
408
// closeParen returns the closing bracket that matches the opening
// bracket ch: ')' for '(' and '}' for '{'. Any other byte yields 0.
func closeParen(ch byte) byte {
	if ch == '(' {
		return ')'
	}
	if ch == '{' {
		return '}'
	}
	return 0
}
418
419// parseDollar parses
420//   $(func expr[, expr...])  # func = literal SP
421//   $(expr:expr=expr)
422//   $(expr)
423//   $x
424// it returns parsed value and parsed length.
425func parseDollar(in []byte, alloc bool) (Value, int, error) {
426	if len(in) <= 1 {
427		return nil, 0, errors.New("empty expr")
428	}
429	if in[0] != '$' {
430		return nil, 0, errors.New("should starts with $")
431	}
432	if in[1] == '$' {
433		return nil, 0, errors.New("should handle $$ as literal $")
434	}
435	oparen := in[1]
436	paren := closeParen(oparen)
437	if paren == 0 {
438		// $x case.
439		if in[1] >= '0' && in[1] <= '9' {
440			return paramref(in[1] - '0'), 2, nil
441		}
442		return &varref{varname: str(in[1:2], alloc)}, 2, nil
443	}
444	term := []byte{paren, ':', ' '}
445	var varname expr
446	i := 2
447	op := parseOp{alloc: alloc}
448Again:
449	for {
450		e, n, err := parseExpr(in[i:], term, op)
451		if err != nil {
452			if err == errEndOfInput {
453				// unmatched_paren2.mk
454				varname = append(varname, toExpr(e)...)
455				if len(varname) > 0 {
456					for i, vn := range varname {
457						if vr, ok := vn.(*varref); ok {
458							if vr.paren == oparen {
459								varname = varname[:i+1]
460								varname[i] = expr{literal(fmt.Sprintf("$%c", oparen)), vr.varname}
461								return &varref{varname: varname, paren: oparen}, i + 1 + n + 1, nil
462							}
463						}
464					}
465				}
466				return nil, 0, errUnterminatedVariableReference
467			}
468			return nil, 0, err
469		}
470		varname = append(varname, toExpr(e)...)
471		i += n
472		switch in[i] {
473		case paren:
474			// ${expr}
475			vname := compactExpr(varname)
476			n, err := valueNum(vname)
477			if err == nil {
478				// ${n}
479				return paramref(n), i + 1, nil
480			}
481			return &varref{varname: vname, paren: oparen}, i + 1, nil
482		case ' ':
483			// ${e ...}
484			switch token := e.(type) {
485			case literal, tmpval:
486				funcName := intern(token.String())
487				if f, ok := funcMap[funcName]; ok {
488					return parseFunc(f(), in, i+1, term[:1], funcName, op.alloc)
489				}
490			}
491			term = term[:2] // drop ' '
492			continue Again
493		case ':':
494			// ${varname:...}
495			colon := in[i : i+1]
496			var vterm []byte
497			vterm = append(vterm, term[:2]...)
498			vterm[1] = '=' // term={paren, '='}.
499			e, n, err := parseExpr(in[i+1:], vterm, op)
500			if err != nil {
501				return nil, 0, err
502			}
503			i += 1 + n
504			if in[i] == paren {
505				varname = appendStr(varname, colon, op.alloc)
506				return &varref{varname: varname, paren: oparen}, i + 1, nil
507			}
508			// ${varname:xx=...}
509			pat := e
510			subst, n, err := parseExpr(in[i+1:], term[:1], op)
511			if err != nil {
512				return nil, 0, err
513			}
514			i += 1 + n
515			// ${first:pat=e}
516			return varsubst{
517				varname: compactExpr(varname),
518				pat:     pat,
519				subst:   subst,
520				paren:   oparen,
521			}, i + 1, nil
522		default:
523			return nil, 0, fmt.Errorf("unexpected char %c at %d in %q", in[i], i, string(in))
524		}
525	}
526}
527
// skipSpaces returns the number of leading spaces and tabs in `in`.
// Scanning stops early at any byte contained in term, even if that byte
// is itself a space or tab. The result n is the index of the first byte
// that was not skipped, or len(in) when everything was skipped.
func skipSpaces(in, term []byte) int {
	for i, ch := range in {
		if bytes.IndexByte(term, ch) >= 0 {
			return i
		}
		if ch != ' ' && ch != '\t' {
			return i
		}
	}
	return len(in)
}
543
544// trimLiteralSpace trims literal space around v.
545func trimLiteralSpace(v Value) Value {
546	switch v := v.(type) {
547	case literal:
548		return literal(strings.TrimSpace(string(v)))
549	case tmpval:
550		b := bytes.TrimSpace([]byte(v))
551		if len(b) == 0 {
552			return literal("")
553		}
554		return tmpval(b)
555	case expr:
556		if len(v) == 0 {
557			return v
558		}
559		switch s := v[0].(type) {
560		case literal, tmpval:
561			t := trimLiteralSpace(s)
562			if t == literal("") {
563				v = v[1:]
564			} else {
565				v[0] = t
566			}
567		}
568		switch s := v[len(v)-1].(type) {
569		case literal, tmpval:
570			t := trimLiteralSpace(s)
571			if t == literal("") {
572				v = v[:len(v)-1]
573			} else {
574				v[len(v)-1] = t
575			}
576		}
577		return compactExpr(v)
578	}
579	return v
580}
581
582// concatLine concatinates line with "\\\n" in function expression.
583// TODO(ukai): less alloc?
584func concatLine(v Value) Value {
585	switch v := v.(type) {
586	case literal:
587		for {
588			s := string(v)
589			i := strings.Index(s, "\\\n")
590			if i < 0 {
591				return v
592			}
593			v = literal(s[:i] + strings.TrimLeft(s[i+2:], " \t"))
594		}
595	case tmpval:
596		for {
597			b := []byte(v)
598			i := bytes.Index(b, []byte{'\\', '\n'})
599			if i < 0 {
600				return v
601			}
602			var buf bytes.Buffer
603			buf.Write(b[:i])
604			buf.Write(bytes.TrimLeft(b[i+2:], " \t"))
605			v = tmpval(buf.Bytes())
606		}
607	case expr:
608		for i := range v {
609			switch vv := v[i].(type) {
610			case literal, tmpval:
611				v[i] = concatLine(vv)
612			}
613		}
614		return v
615	}
616	return v
617}
618
// parseFunc parses function arguments from in[s:] for f.
// in[0] is '$', in[1] is the opening bracket and in[s-1] is the space
// just after func name (parseDollar calls it with s = index of that
// space + 1).
// term holds the closing bracket as term[0]; funcName is used for the
// error message and for the argument trimming of if/and/or.
// It returns the function value and n such that
// in[:n] will be "${func args...}"
func parseFunc(f mkFunc, in []byte, s int, term []byte, funcName string, alloc bool) (Value, int, error) {
	// The first argument is in[1:s-1]: the opening bracket followed by the
	// function name.
	f.AddArg(str(in[1:s-1], alloc))
	arity := f.Arity()
	// Arguments are separated by ','.
	// NOTE(review): append may write into the caller's term backing array.
	term = append(term, ',')
	i := skipSpaces(in[s:], term)
	i = s + i
	if i == len(in) {
		// Nothing but spaces after the function name.
		return f, i, nil
	}
	narg := 1
	// matchParen: brackets nested inside an argument must not end it.
	op := parseOp{alloc: alloc, matchParen: true}
	for {
		if arity != 0 && narg >= arity {
			// final arguments.
			// ',' is ordinary text from here on.
			term = term[:1] // drop ','
		}
		v, n, err := parseExpr(in[i:], term, op)
		if err != nil {
			if err == errEndOfInput {
				return nil, 0, fmt.Errorf("*** unterminated call to function `%s': missing `)'.", funcName)
			}
			return nil, 0, err
		}
		v = concatLine(v)
		// TODO(ukai): do this in funcIf, funcAnd, or funcOr's compactor?
		if (narg == 1 && funcName == "if") || funcName == "and" || funcName == "or" {
			v = trimLiteralSpace(v)
		}
		f.AddArg(v)
		i += n
		narg++
		if in[i] == term[0] {
			// Closing bracket: consume it and finish.
			i++
			break
		}
		i++ // should be ','
		if i == len(in) {
			break
		}
	}
	var fv Value
	fv = f
	// Let the function replace itself with a compacted value if it can.
	if compactor, ok := f.(compactor); ok {
		fv = compactor.Compact()
	}
	// Wrap in funcstats so that evaluation emits a "func" trace event.
	if EvalStatsFlag || traceEvent.enabled() {
		fv = funcstats{
			Value: fv,
			str:   fv.String(),
		}

	}
	return fv, i, nil
}
676
// compactor is implemented by functions that can replace themselves
// with another Value once their arguments are parsed; parseFunc uses the
// result of Compact in place of the function value.
type compactor interface {
	Compact() Value
}
680
681type funcstats struct {
682	Value
683	str string
684}
685
686func (f funcstats) Eval(w evalWriter, ev *Evaluator) error {
687	te := traceEvent.begin("func", literal(f.str), traceEventMain)
688	err := f.Value.Eval(w, ev)
689	if err != nil {
690		return err
691	}
692	// TODO(ukai): per functype?
693	traceEvent.end(te)
694	return nil
695}
696
697type matcherValue struct{}
698
699func (m matcherValue) Eval(w evalWriter, ev *Evaluator) error {
700	return fmt.Errorf("couldn't eval matcher")
701}
702func (m matcherValue) serialize() serializableVar {
703	return serializableVar{Type: ""}
704}
705
706func (m matcherValue) dump(d *dumpbuf) {
707	d.err = fmt.Errorf("couldn't dump matcher")
708}
709
710type matchVarref struct{ matcherValue }
711
712func (m matchVarref) String() string { return "$(match-any)" }
713
714type literalRE struct {
715	matcherValue
716	*regexp.Regexp
717}
718
719func mustLiteralRE(s string) literalRE {
720	return literalRE{
721		Regexp: regexp.MustCompile(s),
722	}
723}
724
725func (r literalRE) String() string { return r.Regexp.String() }
726
727func matchValue(exp, pat Value) bool {
728	switch pat := pat.(type) {
729	case literal:
730		return literal(exp.String()) == pat
731	}
732	// TODO: other type match?
733	return false
734}
735
736func matchExpr(exp, pat expr) ([]Value, bool) {
737	if len(exp) != len(pat) {
738		return nil, false
739	}
740	var mv matchVarref
741	var matches []Value
742	for i := range exp {
743		if pat[i] == mv {
744			switch exp[i].(type) {
745			case paramref, *varref:
746				matches = append(matches, exp[i])
747				continue
748			}
749			return nil, false
750		}
751		if patre, ok := pat[i].(literalRE); ok {
752			re := patre.Regexp
753			m := re.FindStringSubmatch(exp[i].String())
754			if m == nil {
755				return nil, false
756			}
757			for _, sm := range m[1:] {
758				matches = append(matches, literal(sm))
759			}
760			continue
761		}
762		if !matchValue(exp[i], pat[i]) {
763			return nil, false
764		}
765	}
766	return matches, true
767}
768