1package sql
2
3import (
4	"database/sql"
5	"fmt"
6	"runtime"
7	"sync"
8
9	"github.com/GoogleCloudPlatform/cloudsql-proxy/proxy/dialers/mysql"
10	"github.com/pkg/errors"
11
12	"repodiff/constants"
13)
14
15var mux sync.Mutex
16var db *sql.DB
17
18type handleRowFn func(*sql.Rows)
19
// newDBConnectionPool opens a MySQL connection pool through the Cloud SQL
// proxy dialer. The instance connection name, credentials and database name
// are read from the GCP_DB_* configuration variables.
func newDBConnectionPool() (*sql.DB, error) {
	instance := constants.GetConfigVar("GCP_DB_INSTANCE_CONNECTION_NAME")
	user := constants.GetConfigVar("GCP_DB_USER")
	password := constants.GetConfigVar("GCP_DB_PASSWORD")

	dialCfg := mysql.Cfg(instance, user, password)
	dialCfg.DBName = constants.GetConfigVar("GCP_DB_NAME")

	return mysql.DialCfg(dialCfg)
}
29
// maxParallelism reports how many goroutines can usefully run at once: the
// smaller of the current GOMAXPROCS setting and the number of logical CPUs.
func maxParallelism() int {
	procs, cpus := runtime.GOMAXPROCS(0), runtime.NumCPU()
	if cpus <= procs {
		return cpus
	}
	return procs
}
38
39func GetDBConnectionPool() (*sql.DB, error) {
40	if db != nil {
41		return db, nil
42	}
43	mux.Lock()
44	defer mux.Unlock()
45
46	// check, lock, check; redundant check for thread safety
47	if db != nil {
48		return db, nil
49	}
50	var err error
51	db, err = newDBConnectionPool()
52	if err != nil {
53		return nil, err
54	}
55	connections := maxParallelism()
56
57	// unless explicitly specified, the default connection pool size is unlimited
58	db.SetMaxOpenConns(connections)
59
60	// unless explicitly specified, the default is 0 where idle connections are immediately closed
61	db.SetMaxIdleConns(connections)
62	return db, nil
63}
64
// SingleTransactionInsert executes insertQuery once for each entry of
// rowsOfCols, all inside a single transaction on db. Each entry's values are
// bound as that execution's arguments. The statement is prepared once and
// reused for every row.
//
// Either every row is executed and the transaction is committed, or an error
// wrapping the failing step (begin, prepare, insert or commit) is returned.
// On a begin/prepare/insert failure the transaction is rolled back. If Commit
// itself fails, the transaction is abandoned. An empty rowsOfCols commits an
// empty transaction.
func SingleTransactionInsert(db *sql.DB, insertQuery string, rowsOfCols [][]interface{}) error {
	tx, err := db.Begin()
	if err != nil {
		return errors.Wrap(err, "Error starting transaction")
	}
	// Guarantees the transaction, and the pooled connection it pins, are
	// released on every early return, including a Prepare failure. After a
	// successful Commit this is a no-op returning sql.ErrTxDone, so its
	// error is deliberately ignored.
	defer func() { _ = tx.Rollback() }()

	stmt, err := tx.Prepare(insertQuery)
	if err != nil {
		return errors.Wrap(err, "Error preparing statement")
	}
	defer stmt.Close()

	for _, cols := range rowsOfCols {
		if _, err := stmt.Exec(cols...); err != nil {
			return errors.Wrap(err, "Error inserting values")
		}
	}

	if err := tx.Commit(); err != nil {
		return errors.Wrap(
			err,
			"Error committing transaction",
		)
	}
	return nil
}
95
// Select runs query with args against db and calls rowHandler once for each
// returned row, in result order. rowHandler is responsible for scanning the
// current row and has no way to report an error back through Select.
//
// The returned error is either the error from issuing the query or the first
// error encountered while iterating the result set. The rows are closed
// before Select returns.
func Select(db *sql.DB, rowHandler handleRowFn, query string, args ...interface{}) error {
	rows, err := db.Query(query, args...)
	if err != nil {
		return err
	}
	defer rows.Close()

	for rows.Next() {
		rowHandler(rows)
	}
	return rows.Err()
}
114
// TruncateTable executes TRUNCATE TABLE on tableName and returns any error
// from db.Exec.
//
// tableName is inserted into the SQL text verbatim. SQL identifiers cannot
// be bound as query parameters, and this function does not escape it, so
// only pass trusted, code-controlled table names.
func TruncateTable(db *sql.DB, tableName string) error {
	query := fmt.Sprintf("TRUNCATE TABLE %s", tableName)
	_, execErr := db.Exec(query)
	return execErr
}
121