diff --git a/db.go b/db.go index 3255d025..a14f7789 100644 --- a/db.go +++ b/db.go @@ -4,11 +4,8 @@ import ( "context" "database/sql" "database/sql/driver" - "errors" - "strings" - "time" - "go.uber.org/multierr" + "time" ) // DB interface is a contract that supported by this library. @@ -60,40 +57,6 @@ type sqlDB struct { stmtLoadBalancer StmtLoadBalancer } -// OpenMultiPrimary concurrently opens each underlying db connection -// both primaryDataSourceNames and readOnlyDataSourceNames must be a semi-comma separated list of DSNs -// primaryDataSourceNames will be used as the RW-database(primary) -// and readOnlyDataSourceNames as RO databases (replicas). -func OpenMultiPrimary(driverName, primaryDataSourceNames, readOnlyDataSourceNames string) (res DB, err error) { - primaryConns := strings.Split(primaryDataSourceNames, ";") - readOnlyConns := strings.Split(readOnlyDataSourceNames, ";") - - if len(primaryConns) == 0 { - return nil, errors.New("require primary data source name") - } - - opt := defaultOption() - db := &sqlDB{ - replicas: make([]*sql.DB, len(readOnlyConns)), - primaries: make([]*sql.DB, len(primaryConns)), - loadBalancer: opt.DBLB, - stmtLoadBalancer: opt.StmtLB, - } - - db.totalConnection = len(primaryConns) + len(readOnlyConns) - err = doParallely(db.totalConnection, func(i int) (err error) { - if i < len(primaryConns) { - db.primaries[0], err = sql.Open(driverName, primaryConns[i]) - return err - } - roIndex := i - len(primaryConns) - db.replicas[roIndex], err = sql.Open(driverName, readOnlyConns[roIndex]) - return err - }) - - return db, err -} - // PrimaryDBs return all the active primary DB func (db *sqlDB) PrimaryDBs() []*sql.DB { return db.primaries @@ -138,7 +101,7 @@ func (db *sqlDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, err // The args are for any placeholder parameters in the query. // Exec uses the RW-database as the underlying db connection func (db *sqlDB) Exec(query string, args ...interface{}) (sql.Result, error) { - return db.ReadWrite().Exec(query, args...) + return db.ExecContext(context.Background(), query, args...) } // ExecContext executes a query without returning any rows. @@ -151,13 +114,8 @@ func (db *sqlDB) ExecContext(ctx context.Context, query string, args ...interfac // Ping verifies if a connection to each physical database is still alive, // establishing a connection if necessary. func (db *sqlDB) Ping() error { - errPrimaries := doParallely(len(db.primaries), func(i int) error { - return db.primaries[i].Ping() - }) - errReplicas := doParallely(len(db.replicas), func(i int) error { - return db.replicas[i].Ping() - }) - return multierr.Combine(errPrimaries, errReplicas) + return db.PingContext(context.Background()) + } // PingContext verifies if a connection to each physical database is still @@ -175,32 +133,7 @@ func (db *sqlDB) PingContext(ctx context.Context) error { // Prepare creates a prepared statement for later queries or executions // on each physical database, concurrently. func (db *sqlDB) Prepare(query string) (_stmt Stmt, err error) { - roStmts := make([]*sql.Stmt, len(db.replicas)) - primaryStmts := make([]*sql.Stmt, len(db.primaries)) - - errPrimaries := doParallely(len(db.primaries), func(i int) (err error) { - primaryStmts[i], err = db.primaries[i].Prepare(query) - return - }) - errReplicas := doParallely(len(db.replicas), func(i int) (err error) { - roStmts[i], err = db.replicas[i].Prepare(query) - return err - }) - - err = multierr.Combine(errPrimaries, errReplicas) - - if err != nil { - return - } - - _stmt = &stmt{ - db: db, - loadBalancer: db.stmtLoadBalancer, - primaryStmts: primaryStmts, - replicaStmts: roStmts, - } - - return + return db.PrepareContext(context.Background(), query) } // PrepareContext creates a prepared statement for later queries or executions @@ -240,7 +173,7 @@ func (db *sqlDB) PrepareContext(ctx context.Context, query string) (_stmt Stmt, // The args are for any placeholder parameters in the query. // Query uses a radonly db as the physical db. func (db *sqlDB) Query(query string, args ...interface{}) (*sql.Rows, error) { - return db.ReadOnly().Query(query, args...) + return db.QueryContext(context.Background(), query, args...) } // QueryContext executes a query that returns rows, typically a SELECT. @@ -255,7 +188,7 @@ func (db *sqlDB) QueryContext(ctx context.Context, query string, args ...interfa // Errors are deferred until Row's Scan method is called. // QueryRow uses a radonly db as the physical db. func (db *sqlDB) QueryRow(query string, args ...interface{}) *sql.Row { - return db.ReadOnly().QueryRow(query, args...) + return db.QueryRowContext(context.Background(), query, args...) } // QueryRowContext executes a query that is expected to return at most one row. diff --git a/db_test.go b/db_test.go index b5e8875e..bd688f09 100644 --- a/db_test.go +++ b/db_test.go @@ -3,12 +3,28 @@ package dbresolver import ( "context" "database/sql" - "testing" - "github.com/DATA-DOG/go-sqlmock" + "testing" ) func TestMultiWrite(t *testing.T) { + + loadBalancerPolices := []LoadBalancerPolicy{ + RoundRobinLB, + RandomLB, + } + + retrieveLoadBalancer := func() (loadBalancerPolicy LoadBalancerPolicy) { + loadBalancerPolicy = loadBalancerPolices[0] + loadBalancerPolices = loadBalancerPolices[1:] + return + } + +BEGIN_TEST: + loadBalancerPolicy := retrieveLoadBalancer() + + t.Logf("LoadBalancer-%s", loadBalancerPolicy) + testCases := [][2]uint{ {1, 0}, {1, 1}, @@ -33,10 +49,13 @@ func TestMultiWrite(t *testing.T) { return int(testCase[0]), int(testCase[1]) } -BEGIN: +BEGIN_TEST_CASE: if len(testCases) == 0 { - return + if len(loadBalancerPolices) == 0 { + return + } + goto BEGIN_TEST } noOfPrimaries, noOfReplicas := retrieveTestCase() @@ -48,6 +67,7 @@ BEGIN: mockReplicas := make([]sqlmock.Sqlmock, noOfReplicas) for i := 0; i < noOfPrimaries; i++ { + db, mock, err := createMock() if err != nil { @@ -59,9 +79,11 @@ BEGIN: primaries[i] = db mockPimaries[i] = mock + } for i := 0; i < noOfReplicas; i++ { + db, mock, err := createMock() if err != nil { @@ -75,37 +97,43 @@ BEGIN: mockReplicas[i] = mock } - resolver := New(WithPrimaryDBs(primaries...), WithReplicaDBs(replicas...)).(*sqlDB) + resolver := New(WithPrimaryDBs(primaries...), WithReplicaDBs(replicas...), WithLoadBalancer(loadBalancerPolicy)).(*sqlDB) t.Run("primary dbs", func(t *testing.T) { + for i := 0; i < noOfPrimaries*5; i++ { robin := resolver.loadBalancer.predict(noOfPrimaries) mock := mockPimaries[robin] - switch i % 5 { + t.Log("case - ", i%4) + + switch i % 4 { + case 0: query := "SET timezone TO 'Asia/Tokyo'" - expected := mock.ExpectExec(query) - _, _ = resolver.Exec(query) - t.Log("exec", expected.String()) + mock.ExpectExec(query) + resolver.Exec(query) + t.Log("exec") case 1: query := "SET timezone TO 'Asia/Tokyo'" mock.ExpectExec(query) - _, _ = resolver.ExecContext(context.TODO(), query) + resolver.ExecContext(context.TODO(), query) t.Log("exec context") case 2: mock.ExpectBegin() - _, _ = resolver.Begin() + resolver.Begin() t.Log("begin") - case 4: + case 3: mock.ExpectBegin() - _, _ = resolver.BeginTx(context.TODO(), &sql.TxOptions{ + resolver.BeginTx(context.TODO(), &sql.TxOptions{ Isolation: sql.LevelDefault, ReadOnly: false, }) t.Log("begin transaction") - } + default: + t.Fatal("developer needs to work on the tests") + } if err := mock.ExpectationsWereMet(); err != nil { t.Errorf("there were unfulfilled expectations: %s", err) } @@ -113,36 +141,42 @@ BEGIN: }) t.Run("replica dbs", func(t *testing.T) { + for i := 0; i < noOfReplicas*5; i++ { + robin := resolver.loadBalancer.predict(noOfReplicas) mock := mockReplicas[robin] - switch i % 5 { + t.Log("case -", i%4) + + switch i % 4 { + case 0: query := "select 1'" mock.ExpectQuery(query) - res, _ := resolver.Query(query) - _ = res + resolver.Query(query) t.Log("query") case 1: - query := "select 1'" + query := "select 'row'" mock.ExpectQuery(query) - _ = resolver.QueryRow(query) + resolver.QueryRow(query) t.Log("query row") case 2: - query := "select 1'" + query := "select 'query-ctx' " mock.ExpectQuery(query) - res, _ := resolver.QueryContext(context.TODO(), query) - _ = res + resolver.QueryContext(context.TODO(), query) t.Log("query context") - case 4: - query := "select 1'" + case 3: + query := "select 'row'" mock.ExpectQuery(query) - _ = resolver.QueryRowContext(context.TODO(), query) + resolver.QueryRowContext(context.TODO(), query) t.Log("query row context") + default: + t.Fatal("developer needs to work on the tests") + } if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("there were unfulfilled expectations: %s", err) + t.Errorf("expect failed %s", err) } } }) @@ -178,7 +212,8 @@ BEGIN: mock.ExpectExec(query) - _, _ = stmt.Exec() + stmt.Exec() + }) t.Run("ping", func(t *testing.T) { @@ -203,11 +238,11 @@ BEGIN: err := resolver.Ping() if err != nil { - t.Errorf("got %v, want %v", err, nil) + t.Errorf("ping failed %s", err) } err = resolver.PingContext(context.TODO()) if err != nil { - t.Errorf("got %v, want %v", err, nil) + t.Errorf("ping failed %s", err) } }) @@ -233,7 +268,8 @@ BEGIN: t.Logf("%dP%dR", noOfPrimaries, noOfReplicas) }) - goto BEGIN + goto BEGIN_TEST_CASE + } func createMock() (db *sql.DB, mock sqlmock.Sqlmock, err error) { diff --git a/examples/example_wrap_dbs_test.go b/examples/example_wrap_dbs_test.go index 593145ce..b3b0394f 100644 --- a/examples/example_wrap_dbs_test.go +++ b/examples/example_wrap_dbs_test.go @@ -54,5 +54,5 @@ func ExampleNew() { log.Print("go error when executing the query to the DB", err) } _ = connectionDB.QueryRowContext(context.Background(), "SELECT * FROM book WHERE id=$1") // will use replicaReadOnlyDB - // Output : + // Output: } diff --git a/go.mod b/go.mod index c4209515..3b9f2982 100644 --- a/go.mod +++ b/go.mod @@ -2,11 +2,17 @@ module github.com/bxcodec/dbresolver/v2 go 1.19 -require github.com/lib/pq v1.10.6 +require ( + github.com/lib/pq v1.10.6 + github.com/mattn/go-sqlite3 v1.14.14 +) require ( github.com/DATA-DOG/go-sqlmock v1.5.0 go.uber.org/multierr v1.8.0 ) -require go.uber.org/atomic v1.10.0 // indirect +require ( + github.com/golang-jwt/jwt/v4 v4.4.3 // indirect + go.uber.org/atomic v1.10.0 // indirect +) diff --git a/go.sum b/go.sum index 6dfe116b..9d1d19fe 100644 --- a/go.sum +++ b/go.sum @@ -3,8 +3,12 @@ github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang-jwt/jwt/v4 v4.4.3 h1:Hxl6lhQFj4AnOX6MLrsCb/+7tCj7DxP7VA+2rDIq5AU= +github.com/golang-jwt/jwt/v4 v4.4.3/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/lib/pq v1.10.6 h1:jbk+ZieJ0D7EVGJYpL9QTz7/YW6UHbmdnZWYyK5cdBs= github.com/lib/pq v1.10.6/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/mattn/go-sqlite3 v1.14.14 h1:qZgc/Rwetq+MtyE18WhzjokPD93dNqLGNT3QJuLvBGw= +github.com/mattn/go-sqlite3 v1.14.14/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/helper.go b/helper.go index 7e4d662d..6c8cac14 100644 --- a/helper.go +++ b/helper.go @@ -22,7 +22,7 @@ func doParallely(n int, fn func(i int) error) error { close(errors) }(wg) - arrErrs := []error{} + var arrErrs []error for err := range errors { if err != nil { arrErrs = append(arrErrs, err) diff --git a/loadbalancer.go b/loadbalancer.go index 4a2bc335..2acffdbb 100644 --- a/loadbalancer.go +++ b/loadbalancer.go @@ -3,7 +3,6 @@ package dbresolver import ( "database/sql" "math/rand" - "sync" "sync/atomic" "time" ) @@ -22,35 +21,32 @@ type LoadBalancer[T DBConnection] interface { // RandomLoadBalancer represent for Random LB policy type RandomLoadBalancer[T DBConnection] struct { - randomInt int - mu sync.Mutex + randInt chan int } // RandomLoadBalancer return the LB policy name -func (lb *RandomLoadBalancer[T]) Name() LoadBalancerPolicy { +func (lb RandomLoadBalancer[T]) Name() LoadBalancerPolicy { return RandomLB } // Resolve return the resolved option for Random LB -func (lb *RandomLoadBalancer[T]) Resolve(dbs []T) T { - if lb.randomInt == -1 { +func (lb RandomLoadBalancer[T]) Resolve(dbs []T) T { + if len(lb.randInt) == 0 { lb.predict(len(dbs)) } - randomInt := lb.randomInt - lb.mu.Lock() - lb.randomInt = -1 - lb.mu.Unlock() + + randomInt := <-lb.randInt + //log.Println("consumed") return dbs[randomInt] } -func (lb *RandomLoadBalancer[T]) predict(n int) int { +func (lb RandomLoadBalancer[T]) predict(n int) int { rand.Seed(time.Now().UnixNano()) max := n - 1 min := 0 idx := rand.Intn(max-min+1) + min - lb.mu.Lock() - lb.randomInt = idx - lb.mu.Unlock() + lb.randInt <- idx + //log.Println("predicted") return idx } diff --git a/options.go b/options.go index 39396e0c..29a56f89 100644 --- a/options.go +++ b/options.go @@ -44,8 +44,12 @@ func WithLoadBalancer(lb LoadBalancerPolicy) OptionFunc { opt.DBLB = &RoundRobinLoadBalancer[*sql.DB]{} opt.StmtLB = &RoundRobinLoadBalancer[*sql.Stmt]{} case RandomLB: - opt.DBLB = &RandomLoadBalancer[*sql.DB]{} - opt.StmtLB = &RandomLoadBalancer[*sql.Stmt]{} + opt.DBLB = &RandomLoadBalancer[*sql.DB]{ + randInt: make(chan int, 1), + } + opt.StmtLB = &RandomLoadBalancer[*sql.Stmt]{ + randInt: make(chan int, 1), + } } } }