diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/AUTHORS golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/AUTHORS --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/AUTHORS 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/AUTHORS 2019-11-02 13:15:23.000000000 +0000 @@ -96,3 +96,19 @@ Nayef Ghattas MichaƂ Matczuk Ben Krebsbach +Vivian Mathews +Sascha Steinbiss +Seth Rosenblum +Javier Zunzunegui +Luke Hines +Zhixin Wen +Chang Liu +Ingo Oeser +Luke Hines +Jacob Greenleaf +Alex Lourie ; +Marco Cadetg +Karl Matthias +Thomas Meson +Martin Sucha ; +Pavel Buchinchik diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/batch_test.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/batch_test.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/batch_test.go 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/batch_test.go 2019-11-02 13:15:23.000000000 +0000 @@ -1,4 +1,4 @@ -// +build all integration +// +build all cassandra package gocql @@ -9,12 +9,15 @@ func TestBatch_Errors(t *testing.T) { if *flagProto == 1 { - t.Skip("atomic batches not supported. Please use Cassandra >= 2.0") } session := createSession(t) defer session.Close() + if session.cfg.ProtoVersion < protoVersion2 { + t.Skip("atomic batches not supported. Please use Cassandra >= 2.0") + } + if err := createTable(session, `CREATE TABLE gocql_test.batch_errors (id int primary key, val inet)`); err != nil { t.Fatal(err) } @@ -27,13 +30,13 @@ } func TestBatch_WithTimestamp(t *testing.T) { - if *flagProto < protoVersion3 { - t.Skip("Batch timestamps are only available on protocol >= 3") - } - session := createSession(t) defer session.Close() + if session.cfg.ProtoVersion < protoVersion3 { + t.Skip("Batch timestamps are only available on protocol >= 3") + } + if err := createTable(session, `CREATE TABLE gocql_test.batch_ts (id int primary key, val text)`); err != nil { t.Fatal(err) } diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/cass1batch_test.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/cass1batch_test.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/cass1batch_test.go 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/cass1batch_test.go 2019-11-02 13:15:23.000000000 +0000 @@ -1,4 +1,4 @@ -// +build all integration +// +build all cassandra package gocql diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/cassandra_test.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/cassandra_test.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/cassandra_test.go 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/cassandra_test.go 2019-11-02 13:15:23.000000000 +0000 @@ -1,10 +1,12 @@ -// +build all integration +// +build all cassandra package gocql import ( "bytes" "context" + "errors" + "fmt" "io" "math" "math/big" @@ -17,62 +19,9 @@ "time" "unicode" - "gopkg.in/inf.v0" + inf "gopkg.in/inf.v0" ) -// TestAuthentication verifies that gocql will work with a host configured to only accept authenticated connections -func TestAuthentication(t *testing.T) { - - if *flagProto < 2 { - t.Skip("Authentication is not supported with protocol < 2") - } - - if !*flagRunAuthTest { - t.Skip("Authentication is not configured in the target cluster") - } - - cluster := createCluster() - - cluster.Authenticator = PasswordAuthenticator{ - Username: "cassandra", - Password: "cassandra", - } - - session, err := 
cluster.CreateSession() - - if err != nil { - t.Fatalf("Authentication error: %s", err) - } - - session.Close() -} - -//TestRingDiscovery makes sure that you can autodiscover other cluster members when you seed a cluster config with just one node -func TestRingDiscovery(t *testing.T) { - cluster := createCluster() - cluster.Hosts = clusterHosts[:1] - - session := createSessionFromCluster(cluster, t) - defer session.Close() - - if *clusterSize > 1 { - // wait for autodiscovery to update the pool with the list of known hosts - time.Sleep(*flagAutoWait) - } - - session.pool.mu.RLock() - defer session.pool.mu.RUnlock() - size := len(session.pool.hostConnPools) - - if *clusterSize != size { - for p, pool := range session.pool.hostConnPools { - t.Logf("p=%q host=%v ips=%s", p, pool.host, pool.host.ConnectAddress().String()) - - } - t.Errorf("Expected a cluster size of %d, but actual size was %d", *clusterSize, size) - } -} - func TestEmptyHosts(t *testing.T) { cluster := createCluster() cluster.Hosts = nil @@ -83,6 +32,7 @@ } func TestInvalidPeerEntry(t *testing.T) { + t.Skip("dont mutate system tables, rewrite this to test what we mean to test") session := createSession(t) // rack, release_version, schema_version, tokens are all null @@ -155,14 +105,15 @@ } buf := &bytes.Buffer{} - trace := NewTraceWriter(session, buf) - + trace := &traceWriter{session: session, w: buf} if err := session.Query(`INSERT INTO trace (id) VALUES (?)`, 42).Trace(trace).Exec(); err != nil { t.Fatal("insert:", err) } else if buf.Len() == 0 { t.Fatal("insert: failed to obtain any tracing") } + trace.mu.Lock() buf.Reset() + trace.mu.Unlock() var value int if err := session.Query(`SELECT id FROM trace WHERE id = ?`, 42).Trace(trace).Scan(&value); err != nil { @@ -175,7 +126,9 @@ // also works from session tracer session.SetTrace(trace) + trace.mu.Lock() buf.Reset() + trace.mu.Unlock() if err := session.Query(`SELECT id FROM trace WHERE id = ?`, 42).Scan(&value); err != nil { t.Fatal("select:", err) } @@ -184,6 +137,151 @@ } } +func TestObserve(t *testing.T) { + session := createSession(t) + defer session.Close() + + if err := createTable(session, `CREATE TABLE gocql_test.observe (id int primary key)`); err != nil { + t.Fatal("create:", err) + } + + var ( + observedErr error + observedKeyspace string + observedStmt string + ) + + const keyspace = "gocql_test" + + resetObserved := func() { + observedErr = errors.New("placeholder only") // used to distinguish err=nil cases + observedKeyspace = "" + observedStmt = "" + } + + observer := funcQueryObserver(func(ctx context.Context, o ObservedQuery) { + observedKeyspace = o.Keyspace + observedStmt = o.Statement + observedErr = o.Err + }) + + // select before inserted, will error but the reporting is err=nil as the query is valid + resetObserved() + var value int + if err := session.Query(`SELECT id FROM observe WHERE id = ?`, 43).Observer(observer).Scan(&value); err == nil { + t.Fatal("select: expected error") + } else if observedErr != nil { + t.Fatalf("select: observed error expected nil, got %q", observedErr) + } else if observedKeyspace != keyspace { + t.Fatal("select: unexpected observed keyspace", observedKeyspace) + } else if observedStmt != `SELECT id FROM observe WHERE id = ?` { + t.Fatal("select: unexpected observed stmt", observedStmt) + } + + resetObserved() + if err := session.Query(`INSERT INTO observe (id) VALUES (?)`, 42).Observer(observer).Exec(); err != nil { + t.Fatal("insert:", err) + } else if observedErr != nil { + t.Fatal("insert:", observedErr) + } else 
if observedKeyspace != keyspace { + t.Fatal("insert: unexpected observed keyspace", observedKeyspace) + } else if observedStmt != `INSERT INTO observe (id) VALUES (?)` { + t.Fatal("insert: unexpected observed stmt", observedStmt) + } + + resetObserved() + value = 0 + if err := session.Query(`SELECT id FROM observe WHERE id = ?`, 42).Observer(observer).Scan(&value); err != nil { + t.Fatal("select:", err) + } else if value != 42 { + t.Fatalf("value: expected %d, got %d", 42, value) + } else if observedErr != nil { + t.Fatal("select:", observedErr) + } else if observedKeyspace != keyspace { + t.Fatal("select: unexpected observed keyspace", observedKeyspace) + } else if observedStmt != `SELECT id FROM observe WHERE id = ?` { + t.Fatal("select: unexpected observed stmt", observedStmt) + } + + // also works from session observer + resetObserved() + oSession := createSession(t, func(config *ClusterConfig) { config.QueryObserver = observer }) + if err := oSession.Query(`SELECT id FROM observe WHERE id = ?`, 42).Scan(&value); err != nil { + t.Fatal("select:", err) + } else if observedErr != nil { + t.Fatal("select:", err) + } else if observedKeyspace != keyspace { + t.Fatal("select: unexpected observed keyspace", observedKeyspace) + } else if observedStmt != `SELECT id FROM observe WHERE id = ?` { + t.Fatal("select: unexpected observed stmt", observedStmt) + } + + // reports errors when the query is poorly formed + resetObserved() + value = 0 + if err := session.Query(`SELECT id FROM unknown_table WHERE id = ?`, 42).Observer(observer).Scan(&value); err == nil { + t.Fatal("select: expecting error") + } else if observedErr == nil { + t.Fatal("select: expecting observed error") + } else if observedKeyspace != keyspace { + t.Fatal("select: unexpected observed keyspace", observedKeyspace) + } else if observedStmt != `SELECT id FROM unknown_table WHERE id = ?` { + t.Fatal("select: unexpected observed stmt", observedStmt) + } +} + +func TestObserve_Pagination(t *testing.T) { + session := createSession(t) + defer session.Close() + + if err := createTable(session, `CREATE TABLE gocql_test.observe2 (id int, PRIMARY KEY (id))`); err != nil { + t.Fatal("create:", err) + } + + var observedRows int + + resetObserved := func() { + observedRows = -1 + } + + observer := funcQueryObserver(func(ctx context.Context, o ObservedQuery) { + observedRows = o.Rows + }) + + // insert 100 entries, relevant for pagination + for i := 0; i < 50; i++ { + if err := session.Query(`INSERT INTO observe2 (id) VALUES (?)`, i).Exec(); err != nil { + t.Fatal("insert:", err) + } + } + + resetObserved() + + // read the 100 entries in paginated entries of size 10. Expecting 5 observations, each with 10 rows + scanner := session.Query(`SELECT id FROM observe2 LIMIT 100`). + Observer(observer). + PageSize(10). 
+ Iter().Scanner() + for i := 0; i < 50; i++ { + if !scanner.Next() { + t.Fatalf("next: should still be true: %d: %v", i, scanner.Err()) + } + if i%10 == 0 { + if observedRows != 10 { + t.Fatalf("next: expecting a paginated query with 10 entries, got: %d (%d)", observedRows, i) + } + } else if observedRows != -1 { + t.Fatalf("next: not expecting paginated query (-1 entries), got: %d", observedRows) + } + + resetObserved() + } + + if scanner.Next() { + t.Fatal("next: no more entries where expected") + } +} + func TestPaging(t *testing.T) { session := createSession(t) defer session.Close() @@ -215,6 +313,50 @@ } } +func TestPagingWithBind(t *testing.T) { + session := createSession(t) + defer session.Close() + + if session.cfg.ProtoVersion == 1 { + t.Skip("Paging not supported. Please use Cassandra >= 2.0") + } + + if err := createTable(session, "CREATE TABLE gocql_test.paging_bind (id int, val int, primary key(id,val))"); err != nil { + t.Fatal("create table:", err) + } + for i := 0; i < 100; i++ { + if err := session.Query("INSERT INTO paging_bind (id,val) VALUES (?,?)", 1,i).Exec(); err != nil { + t.Fatal("insert:", err) + } + } + + q := session.Query("SELECT val FROM paging_bind WHERE id = ? AND val < ?",1, 50).PageSize(10) + iter := q.Iter() + var id int + count := 0 + for iter.Scan(&id) { + count++ + } + if err := iter.Close(); err != nil { + t.Fatal("close:", err) + } + if count != 50 { + t.Fatalf("expected %d, got %d", 50, count) + } + + iter = q.Bind(1, 20).Iter() + count = 0 + for iter.Scan(&id) { + count++ + } + if count != 20 { + t.Fatalf("expected %d, got %d", 20, count) + } + if err := iter.Close(); err != nil { + t.Fatal("close:", err) + } +} + func TestCAS(t *testing.T) { cluster := createCluster() cluster.SerialConsistency = LocalSerial @@ -337,6 +479,58 @@ } } +func TestDurationType(t *testing.T) { + session := createSession(t) + defer session.Close() + + if session.cfg.ProtoVersion < 5 { + t.Skip("Duration type is not supported. 
Please use protocol version >= 4 and cassandra version >= 3.11") + } + + if err := createTable(session, `CREATE TABLE gocql_test.duration_table ( + k int primary key, v duration + )`); err != nil { + t.Fatal("create:", err) + } + + durations := []Duration{ + Duration{ + Months: 250, + Days: 500, + Nanoseconds: 300010001, + }, + Duration{ + Months: -250, + Days: -500, + Nanoseconds: -300010001, + }, + Duration{ + Months: 0, + Days: 128, + Nanoseconds: 127, + }, + Duration{ + Months: 0x7FFFFFFF, + Days: 0x7FFFFFFF, + Nanoseconds: 0x7FFFFFFFFFFFFFFF, + }, + } + for _, durationSend := range durations { + if err := session.Query(`INSERT INTO gocql_test.duration_table (k, v) VALUES (1, ?)`, durationSend).Exec(); err != nil { + t.Fatal(err) + } + + var id int + var duration Duration + if err := session.Query(`SELECT k, v FROM gocql_test.duration_table`).Scan(&id, &duration); err != nil { + t.Fatal(err) + } + if duration.Months != durationSend.Months || duration.Days != durationSend.Days || duration.Nanoseconds != durationSend.Nanoseconds { + t.Fatalf("Unexpeted value returned, expected=%v, received=%v", durationSend, duration) + } + } +} + func TestMapScanCAS(t *testing.T) { session := createSession(t) defer session.Close() @@ -391,7 +585,7 @@ t.Fatal("create table:", err) } - batch := NewBatch(LoggedBatch) + batch := session.NewBatch(LoggedBatch) for i := 0; i < 100; i++ { batch.Query(`INSERT INTO batch_table (id) VALUES (?)`, i) } @@ -423,9 +617,9 @@ var batch *Batch if session.cfg.ProtoVersion == 2 { - batch = NewBatch(CounterBatch) + batch = session.NewBatch(CounterBatch) } else { - batch = NewBatch(UnloggedBatch) + batch = session.NewBatch(UnloggedBatch) } for i := 0; i < 100; i++ { @@ -464,7 +658,7 @@ t.Fatal("create table:", err) } - batch := NewBatch(LoggedBatch) + batch := session.NewBatch(LoggedBatch) for i := 0; i < 65537; i++ { batch.Query(`INSERT INTO batch_table2 (id) VALUES (?)`, i) } @@ -991,7 +1185,7 @@ } if value != "w00t" { - t.Fatalf("expected %v but got %v", "quux", value) + t.Fatalf("expected %v but got %v", "w00t", value) } } @@ -1204,15 +1398,15 @@ } func TestPrepare_MissingSchemaPrepare(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + s := createSession(t) conn := getRandomConn(t, s) defer s.Close() - insertQry := &Query{stmt: "INSERT INTO invalidschemaprep (val) VALUES (?)", values: []interface{}{5}, cons: s.cons, - session: s, pageSize: s.pageSize, trace: s.trace, - prefetch: s.prefetch, rt: s.cfg.RetryPolicy} - - if err := conn.executeQuery(insertQry).err; err == nil { + insertQry := s.Query("INSERT INTO invalidschemaprep (val) VALUES (?)", 5) + if err := conn.executeQuery(ctx, insertQry).err; err == nil { t.Fatal("expected error, but got nil.") } @@ -1220,22 +1414,29 @@ t.Fatal("create table:", err) } - if err := conn.executeQuery(insertQry).err; err != nil { + if err := conn.executeQuery(ctx, insertQry).err; err != nil { t.Fatal(err) // unconfigured columnfamily } } func TestPrepare_ReprepareStatement(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + session := createSession(t) defer session.Close() + stmt, conn := injectInvalidPreparedStatement(t, session, "test_reprepare_statement") query := session.Query(stmt, "bar") - if err := conn.executeQuery(query).Close(); err != nil { + if err := conn.executeQuery(ctx, query).Close(); err != nil { t.Fatalf("Failed to execute query for reprepare statement: %v", err) } } func TestPrepare_ReprepareBatch(t *testing.T) { + ctx, cancel := 
context.WithCancel(context.Background()) + defer cancel() + session := createSession(t) defer session.Close() @@ -1246,7 +1447,7 @@ stmt, conn := injectInvalidPreparedStatement(t, session, "test_reprepare_statement_batch") batch := session.NewBatch(UnloggedBatch) batch.Query(stmt, "bar") - if err := conn.executeBatch(batch).Close(); err != nil { + if err := conn.executeBatch(ctx, batch).Close(); err != nil { t.Fatalf("Failed to execute query for reprepare statement: %v", err) } } @@ -1277,6 +1478,7 @@ func TestPrepare_PreparedCacheEviction(t *testing.T) { const maxPrepared = 4 + clusterHosts := getClusterHosts() host := clusterHosts[0] cluster := createCluster() cluster.MaxPreparedStmts = maxPrepared @@ -1643,6 +1845,71 @@ } } +type funcBatchObserver func(context.Context, ObservedBatch) + +func (f funcBatchObserver) ObserveBatch(ctx context.Context, o ObservedBatch) { + f(ctx, o) +} + +func TestBatchObserve(t *testing.T) { + session := createSession(t) + defer session.Close() + + if session.cfg.ProtoVersion == 1 { + t.Skip("atomic batches not supported. Please use Cassandra >= 2.0") + } + + if err := createTable(session, `CREATE TABLE gocql_test.batch_observe_table (id int, other int, PRIMARY KEY (id))`); err != nil { + t.Fatal("create table:", err) + } + + type observation struct { + observedErr error + observedKeyspace string + observedStmts []string + } + + var observedBatch *observation + + batch := session.NewBatch(LoggedBatch) + batch.Observer(funcBatchObserver(func(ctx context.Context, o ObservedBatch) { + if observedBatch != nil { + t.Fatal("batch observe called more than once") + } + + observedBatch = &observation{ + observedKeyspace: o.Keyspace, + observedStmts: o.Statements, + observedErr: o.Err, + } + })) + for i := 0; i < 100; i++ { + // hard coding 'i' into one of the values for better testing of observation + batch.Query(fmt.Sprintf(`INSERT INTO batch_observe_table (id,other) VALUES (?,%d)`, i), i) + } + + if err := session.ExecuteBatch(batch); err != nil { + t.Fatal("execute batch:", err) + } + if observedBatch == nil { + t.Fatal("batch observation has not been called") + } + if len(observedBatch.observedStmts) != 100 { + t.Fatal("expecting 100 observed statements, got", len(observedBatch.observedStmts)) + } + if observedBatch.observedErr != nil { + t.Fatal("not expecting to observe an error", observedBatch.observedErr) + } + if observedBatch.observedKeyspace != "gocql_test" { + t.Fatalf("expecting keyspace 'gocql_test', got %q", observedBatch.observedKeyspace) + } + for i, stmt := range observedBatch.observedStmts { + if stmt != fmt.Sprintf(`INSERT INTO batch_observe_table (id,other) VALUES (?,%d)`, i) { + t.Fatal("unexpected query", stmt) + } + } +} + //TestNilInQuery tests to see that a nil value passed to a query is handled by Cassandra //TODO validate the nil value by reading back the nil. Need to fix Unmarshalling. 
func TestNilInQuery(t *testing.T) { @@ -1927,6 +2194,165 @@ } } +func TestViewMetadata(t *testing.T) { + session := createSession(t) + defer session.Close() + createViews(t, session) + + views, err := getViewsMetadata(session, "gocql_test") + if err != nil { + t.Fatalf("failed to query view metadata with err: %v", err) + } + if views == nil { + t.Fatal("failed to query view metadata, nil returned") + } + + if len(views) != 1 { + t.Fatal("expected one view") + } + + textType := TypeText + if flagCassVersion.Before(3, 0, 0) { + textType = TypeVarchar + } + + expectedView := ViewMetadata{ + Keyspace: "gocql_test", + Name: "basicview", + FieldNames: []string{"birthday", "nationality", "weight", "height"}, + FieldTypes: []TypeInfo{ + NativeType{typ: TypeTimestamp}, + NativeType{typ: textType}, + NativeType{typ: textType}, + NativeType{typ: textType}, + }, + } + + if !reflect.DeepEqual(views[0], expectedView) { + t.Fatalf("view is %+v, but expected %+v", views[0], expectedView) + } +} + +func TestAggregateMetadata(t *testing.T) { + session := createSession(t) + defer session.Close() + createAggregate(t, session) + + aggregates, err := getAggregatesMetadata(session, "gocql_test") + if err != nil { + t.Fatalf("failed to query aggregate metadata with err: %v", err) + } + if aggregates == nil { + t.Fatal("failed to query aggregate metadata, nil returned") + } + if len(aggregates) != 1 { + t.Fatal("expected only a single aggregate") + } + aggregate := aggregates[0] + + expectedAggregrate := AggregateMetadata{ + Keyspace: "gocql_test", + Name: "average", + ArgumentTypes: []TypeInfo{NativeType{typ: TypeInt}}, + InitCond: "(0, 0)", + ReturnType: NativeType{typ: TypeDouble}, + StateType: TupleTypeInfo{ + NativeType: NativeType{typ: TypeTuple}, + + Elems: []TypeInfo{ + NativeType{typ: TypeInt}, + NativeType{typ: TypeBigInt}, + }, + }, + stateFunc: "avgstate", + finalFunc: "avgfinal", + } + + // In this case cassandra is returning a blob + if flagCassVersion.Before(3, 0, 0) { + expectedAggregrate.InitCond = string([]byte{0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0}) + } + + if !reflect.DeepEqual(aggregate, expectedAggregrate) { + t.Fatalf("aggregate is %+v, but expected %+v", aggregate, expectedAggregrate) + } +} + +func TestFunctionMetadata(t *testing.T) { + session := createSession(t) + defer session.Close() + createFunctions(t, session) + + functions, err := getFunctionsMetadata(session, "gocql_test") + if err != nil { + t.Fatalf("failed to query function metadata with err: %v", err) + } + if functions == nil { + t.Fatal("failed to query function metadata, nil returned") + } + if len(functions) != 2 { + t.Fatal("expected two functions") + } + avgState := functions[1] + avgFinal := functions[0] + + avgStateBody := "if (val !=null) {state.setInt(0, state.getInt(0)+1); state.setLong(1, state.getLong(1)+val.intValue());}return state;" + expectedAvgState := FunctionMetadata{ + Keyspace: "gocql_test", + Name: "avgstate", + ArgumentTypes: []TypeInfo{ + TupleTypeInfo{ + NativeType: NativeType{typ: TypeTuple}, + + Elems: []TypeInfo{ + NativeType{typ: TypeInt}, + NativeType{typ: TypeBigInt}, + }, + }, + NativeType{typ: TypeInt}, + }, + ArgumentNames: []string{"state", "val"}, + ReturnType: TupleTypeInfo{ + NativeType: NativeType{typ: TypeTuple}, + + Elems: []TypeInfo{ + NativeType{typ: TypeInt}, + NativeType{typ: TypeBigInt}, + }, + }, + CalledOnNullInput: true, + Language: "java", + Body: avgStateBody, + } + if !reflect.DeepEqual(avgState, expectedAvgState) { + t.Fatalf("function is %+v, but 
expected %+v", avgState, expectedAvgState) + } + + finalStateBody := "double r = 0; if (state.getInt(0) == 0) return null; r = state.getLong(1); r/= state.getInt(0); return Double.valueOf(r);" + expectedAvgFinal := FunctionMetadata{ + Keyspace: "gocql_test", + Name: "avgfinal", + ArgumentTypes: []TypeInfo{ + TupleTypeInfo{ + NativeType: NativeType{typ: TypeTuple}, + + Elems: []TypeInfo{ + NativeType{typ: TypeInt}, + NativeType{typ: TypeBigInt}, + }, + }, + }, + ArgumentNames: []string{"state"}, + ReturnType: NativeType{typ: TypeDouble}, + CalledOnNullInput: true, + Language: "java", + Body: finalStateBody, + } + if !reflect.DeepEqual(avgFinal, expectedAvgFinal) { + t.Fatalf("function is %+v, but expected %+v", avgFinal, expectedAvgFinal) + } +} + // Integration test of querying and composition the keyspace metadata func TestKeyspaceMetadata(t *testing.T) { session := createSession(t) @@ -1935,6 +2361,8 @@ if err := createTable(session, "CREATE TABLE gocql_test.test_metadata (first_id int, second_id int, third_id int, PRIMARY KEY (first_id, second_id))"); err != nil { t.Fatalf("failed to create table with error '%v'", err) } + createAggregate(t, session) + createViews(t, session) if err := session.Query("CREATE INDEX index_metadata ON test_metadata ( third_id )").Exec(); err != nil { t.Fatalf("failed to create index with err: %v", err) @@ -1989,6 +2417,22 @@ // TODO(zariel): scan index info from system_schema t.Errorf("Expected column index named 'index_metadata' but was '%s'", thirdColumn.Index.Name) } + + aggregate, found := keyspaceMetadata.Aggregates["average"] + if !found { + t.Fatal("failed to find the aggreate in metadata") + } + if aggregate.FinalFunc.Name != "avgfinal" { + t.Fatalf("expected final function %s, but got %s", "avgFinal", aggregate.FinalFunc.Name) + } + if aggregate.StateFunc.Name != "avgstate" { + t.Fatalf("expected state function %s, but got %s", "avgstate", aggregate.StateFunc.Name) + } + + _, found = keyspaceMetadata.Views["basicview"] + if !found { + t.Fatal("failed to find the view in metadata") + } } // Integration test of the routing key calculation @@ -2335,58 +2779,36 @@ } } -func TestUDF(t *testing.T) { - session := createSession(t) - defer session.Close() - - if session.cfg.ProtoVersion < 4 { - t.Skip("skipping UDF support on proto < 4") - } - - const query = `CREATE OR REPLACE FUNCTION uniq(state set, val text) - CALLED ON NULL INPUT RETURNS set LANGUAGE java - AS 'state.add(val); return state;'` - - err := session.Query(query).Exec() - if err != nil { - t.Fatal(err) - } -} - func TestDiscoverViaProxy(t *testing.T) { // This (complicated) test tests that when the driver is given an initial host // that is infact a proxy it discovers the rest of the ring behind the proxy // and does not store the proxies address as a host in its connection pool. 
// See https://github.com/gocql/gocql/issues/481 + clusterHosts := getClusterHosts() proxy, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("unable to create proxy listener: %v", err) } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() var ( - wg sync.WaitGroup mu sync.Mutex proxyConns []net.Conn closed bool ) - go func(wg *sync.WaitGroup) { + go func() { cassandraAddr := JoinHostPort(clusterHosts[0], 9042) cassandra := func() (net.Conn, error) { return net.Dial("tcp", cassandraAddr) } - proxyFn := func(wg *sync.WaitGroup, from, to net.Conn) { - defer wg.Done() - + proxyFn := func(errs chan error, from, to net.Conn) { _, err := io.Copy(to, from) if err != nil { - mu.Lock() - if !closed { - t.Error(err) - } - mu.Unlock() + errs <- err } } @@ -2394,29 +2816,22 @@ // for both the read and write side of the TCP connection to close before // returning. handle := func(conn net.Conn) error { - defer conn.Close() - cass, err := cassandra() if err != nil { return err } - - mu.Lock() - proxyConns = append(proxyConns, cass) - mu.Unlock() - defer cass.Close() - var wg sync.WaitGroup - wg.Add(1) - go proxyFn(&wg, conn, cass) - - wg.Add(1) - go proxyFn(&wg, cass, conn) - - wg.Wait() - - return nil + errs := make(chan error, 2) + go proxyFn(errs, conn, cass) + go proxyFn(errs, cass, conn) + + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-errs: + return err + } } for { @@ -2436,19 +2851,19 @@ proxyConns = append(proxyConns, conn) mu.Unlock() - wg.Add(1) go func(conn net.Conn) { - defer wg.Done() + defer conn.Close() if err := handle(conn); err != nil { - t.Error(err) - return + mu.Lock() + if !closed { + t.Error(err) + } + mu.Unlock() } }(conn) } - }(&wg) - - defer wg.Wait() + }() proxyAddr := proxy.Addr().String() @@ -2460,11 +2875,6 @@ session := createSessionFromCluster(cluster, t) defer session.Close() - if session.hostSource.localHost.BroadcastAddress() == nil { - t.Skip("Target cluster does not have broadcast_address in system.local.") - goto close - } - // we shouldnt need this but to be safe time.Sleep(1 * time.Second) @@ -2476,7 +2886,6 @@ } session.pool.mu.RUnlock() -close: mu.Lock() closed = true if err := proxy.Close(); err != nil { @@ -2529,7 +2938,7 @@ } func TestSchemaReset(t *testing.T) { - if flagCassVersion.Major == 0 || (flagCassVersion.Before(2, 1, 3)) { + if flagCassVersion.Major == 0 || flagCassVersion.Before(2, 1, 3) { t.Skipf("skipping TestSchemaReset due to CASSANDRA-7910 in Cassandra <2.1.3 version=%v", flagCassVersion) } diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/cluster.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/cluster.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/cluster.go 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/cluster.go 2019-11-02 13:15:23.000000000 +0000 @@ -44,23 +44,26 @@ // If it is 0 or unset (the default) then the driver will attempt to discover the // highest supported protocol for the cluster. 
In clusters with nodes of different // versions the protocol selected is not defined (ie, it can be any of the supported in the cluster) - ProtoVersion int - Timeout time.Duration // connection timeout (default: 600ms) - ConnectTimeout time.Duration // initial connection timeout, used during initial dial to server (default: 600ms) - Port int // port (default: 9042) - Keyspace string // initial keyspace (optional) - NumConns int // number of connections per host (default: 2) - Consistency Consistency // default consistency level (default: Quorum) - Compressor Compressor // compression algorithm (default: nil) - Authenticator Authenticator // authenticator (default: nil) - RetryPolicy RetryPolicy // Default retry policy to use for queries (default: 0) - SocketKeepalive time.Duration // The keepalive period to use, enabled if > 0 (default: 0) - MaxPreparedStmts int // Sets the maximum cache size for prepared statements globally for gocql (default: 1000) - MaxRoutingKeyInfo int // Sets the maximum cache size for query info about statements for each session (default: 1000) - PageSize int // Default page size to use for created sessions (default: 5000) - SerialConsistency SerialConsistency // Sets the consistency for the serial part of queries, values can be either SERIAL or LOCAL_SERIAL (default: unset) - SslOpts *SslOptions - DefaultTimestamp bool // Sends a client side timestamp for all requests which overrides the timestamp at which it arrives at the server. (default: true, only enabled for protocol 3 and above) + ProtoVersion int + Timeout time.Duration // connection timeout (default: 600ms) + ConnectTimeout time.Duration // initial connection timeout, used during initial dial to server (default: 600ms) + Port int // port (default: 9042) + Keyspace string // initial keyspace (optional) + NumConns int // number of connections per host (default: 2) + Consistency Consistency // default consistency level (default: Quorum) + Compressor Compressor // compression algorithm (default: nil) + Authenticator Authenticator // authenticator (default: nil) + AuthProvider func(h *HostInfo) (Authenticator, error) // an authenticator factory. Can be used to create alternative authenticators (default: nil) + RetryPolicy RetryPolicy // Default retry policy to use for queries (default: 0) + ConvictionPolicy ConvictionPolicy // Decide whether to mark host as down based on the error and host info (default: SimpleConvictionPolicy) + ReconnectionPolicy ReconnectionPolicy // Default reconnection policy to use for reconnecting before trying to mark host as down (default: see below) + SocketKeepalive time.Duration // The keepalive period to use, enabled if > 0 (default: 0) + MaxPreparedStmts int // Sets the maximum cache size for prepared statements globally for gocql (default: 1000) + MaxRoutingKeyInfo int // Sets the maximum cache size for query info about statements for each session (default: 1000) + PageSize int // Default page size to use for created sessions (default: 5000) + SerialConsistency SerialConsistency // Sets the consistency for the serial part of queries, values can be either SERIAL or LOCAL_SERIAL (default: unset) + SslOpts *SslOptions + DefaultTimestamp bool // Sends a client side timestamp for all requests which overrides the timestamp at which it arrives at the server. (default: true, only enabled for protocol 3 and above) // PoolConfig configures the underlying connection pool, allowing the // configuration of host selection and connection selection policies. 
PoolConfig PoolConfig @@ -69,7 +72,7 @@ ReconnectInterval time.Duration // The maximum amount of time to wait for schema agreement in a cluster after - // receiving a schema change frame. (deault: 60s) + // receiving a schema change frame. (default: 60s) MaxWaitSchemaAgreement time.Duration // HostFilter will filter all incoming events for host, any which don't pass @@ -115,6 +118,32 @@ // See https://issues.apache.org/jira/browse/CASSANDRA-10786 DisableSkipMetadata bool + // QueryObserver will set the provided query observer on all queries created from this session. + // Use it to collect metrics / stats from queries by providing an implementation of QueryObserver. + QueryObserver QueryObserver + + // BatchObserver will set the provided batch observer on all queries created from this session. + // Use it to collect metrics / stats from batch queries by providing an implementation of BatchObserver. + BatchObserver BatchObserver + + // ConnectObserver will set the provided connect observer on all queries + // created from this session. + ConnectObserver ConnectObserver + + // FrameHeaderObserver will set the provided frame header observer on all frames' headers created from this session. + // Use it to collect metrics / stats from frames by providing an implementation of FrameHeaderObserver. + FrameHeaderObserver FrameHeaderObserver + + // Default idempotence for queries + DefaultIdempotence bool + + // The time to wait for frames before flushing the frames connection to Cassandra. + // Can help reduce syscall overhead by making less calls to write. Set to 0 to + // disable. + // + // (default: 200 microseconds) + WriteCoalesceWaitTime time.Duration + // internal config for testing disableControlConn bool } @@ -143,6 +172,9 @@ DefaultTimestamp: true, MaxWaitSchemaAgreement: 60 * time.Second, ReconnectInterval: 60 * time.Second, + ConvictionPolicy: &SimpleConvictionPolicy{}, + ReconnectionPolicy: &ConstantReconnectionPolicy{MaxRetries: 3, Interval: 1 * time.Second}, + WriteCoalesceWaitTime: 200 * time.Microsecond, } return cfg } @@ -168,6 +200,10 @@ return newAddr, newPort } +func (cfg *ClusterConfig) filterHost(host *HostInfo) bool { + return !(cfg.HostFilter == nil || cfg.HostFilter.Accept(host)) +} + var ( ErrNoHosts = errors.New("no hosts provided") ErrNoConnectionsStarted = errors.New("no connections were made when creating the session") diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/cluster_test.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/cluster_test.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/cluster_test.go 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/cluster_test.go 2019-11-02 13:15:23.000000000 +0000 @@ -1,9 +1,10 @@ package gocql import ( + "net" + "reflect" "testing" "time" - "net" ) func TestNewCluster_Defaults(t *testing.T) { @@ -19,6 +20,10 @@ assertEqual(t, "cluster config default timestamp", true, cfg.DefaultTimestamp) assertEqual(t, "cluster config max wait schema agreement", 60*time.Second, cfg.MaxWaitSchemaAgreement) assertEqual(t, "cluster config reconnect interval", 60*time.Second, cfg.ReconnectInterval) + assertTrue(t, "cluster config conviction policy", + reflect.DeepEqual(&SimpleConvictionPolicy{}, cfg.ConvictionPolicy)) + assertTrue(t, "cluster config reconnection policy", + reflect.DeepEqual(&ConstantReconnectionPolicy{MaxRetries: 3, Interval: 1 * time.Second}, cfg.ReconnectionPolicy)) } func TestNewCluster_WithHosts(t *testing.T) { diff -Nru 
golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/common_test.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/common_test.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/common_test.go 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/common_test.go 2019-11-02 13:15:23.000000000 +0000 @@ -5,6 +5,7 @@ "fmt" "log" "net" + "reflect" "strings" "sync" "testing" @@ -25,17 +26,18 @@ flagTimeout = flag.Duration("gocql.timeout", 5*time.Second, "sets the connection `timeout` for all operations") flagCassVersion cassVersion - clusterHosts []string ) func init() { flag.Var(&flagCassVersion, "gocql.cversion", "the cassandra version being tested against") - flag.Parse() - clusterHosts = strings.Split(*flagCluster, ",") log.SetFlags(log.Lshortfile | log.LstdFlags) } +func getClusterHosts() []string { + return strings.Split(*flagCluster, ",") +} + func addSslOptions(cluster *ClusterConfig) *ClusterConfig { if *flagRunSslTest { cluster.SslOpts = &SslOptions{ @@ -57,7 +59,7 @@ return err } - if err := s.Query(table).RetryPolicy(nil).Exec(); err != nil { + if err := s.Query(table).RetryPolicy(&SimpleRetryPolicy{}).Exec(); err != nil { log.Printf("error creating table table=%q err=%v\n", table, err) return err } @@ -70,7 +72,8 @@ return nil } -func createCluster() *ClusterConfig { +func createCluster(opts ...func(*ClusterConfig)) *ClusterConfig { + clusterHosts := getClusterHosts() cluster := NewCluster(clusterHosts...) cluster.ProtoVersion = *flagProto cluster.CQLVersion = *flagCQL @@ -90,10 +93,16 @@ } cluster = addSslOptions(cluster) + + for _, opt := range opts { + opt(cluster) + } + return cluster } func createKeyspace(tb testing.TB, cluster *ClusterConfig, keyspace string) { + // TODO: tb.Helper() c := *cluster c.Keyspace = "system" c.Timeout = 30 * time.Second @@ -102,7 +111,6 @@ panic(err) } defer session.Close() - defer tb.Log("closing keyspace session") err = createTable(session, `DROP KEYSPACE IF EXISTS `+keyspace) if err != nil { @@ -140,8 +148,8 @@ return session } -func createSession(tb testing.TB) *Session { - cluster := createCluster() +func createSession(tb testing.TB, opts ...func(config *ClusterConfig)) *Session { + cluster := createCluster(opts...) 
return createSessionFromCluster(cluster, tb) } @@ -165,6 +173,50 @@ return session } +func createViews(t *testing.T, session *Session) { + if err := session.Query(` + CREATE TYPE IF NOT EXISTS gocql_test.basicView ( + birthday timestamp, + nationality text, + weight text, + height text); `).Exec(); err != nil { + t.Fatalf("failed to create view with err: %v", err) + } +} + +func createFunctions(t *testing.T, session *Session) { + if err := session.Query(` + CREATE OR REPLACE FUNCTION gocql_test.avgState ( state tuple, val int ) + CALLED ON NULL INPUT + RETURNS tuple + LANGUAGE java AS + $$if (val !=null) {state.setInt(0, state.getInt(0)+1); state.setLong(1, state.getLong(1)+val.intValue());}return state;$$; `).Exec(); err != nil { + t.Fatalf("failed to create function with err: %v", err) + } + if err := session.Query(` + CREATE OR REPLACE FUNCTION gocql_test.avgFinal ( state tuple ) + CALLED ON NULL INPUT + RETURNS double + LANGUAGE java AS + $$double r = 0; if (state.getInt(0) == 0) return null; r = state.getLong(1); r/= state.getInt(0); return Double.valueOf(r);$$ + `).Exec(); err != nil { + t.Fatalf("failed to create function with err: %v", err) + } +} + +func createAggregate(t *testing.T, session *Session) { + createFunctions(t, session) + if err := session.Query(` + CREATE OR REPLACE AGGREGATE gocql_test.average(int) + SFUNC avgState + STYPE tuple + FINALFUNC avgFinal + INITCOND (0,0); + `).Exec(); err != nil { + t.Fatalf("failed to create aggregate with err: %v", err) + } +} + func staticAddressTranslator(newAddr net.IP, newPort int) AddressTranslator { return AddressTranslatorFunc(func(addr net.IP, port int) (net.IP, int) { return newAddr, newPort @@ -172,25 +224,36 @@ } func assertTrue(t *testing.T, description string, value bool) { + t.Helper() if !value { - t.Errorf("expected %s to be true", description) + t.Fatalf("expected %s to be true", description) } } func assertEqual(t *testing.T, description string, expected, actual interface{}) { + t.Helper() if expected != actual { - t.Errorf("expected %s to be (%+v) but was (%+v) instead", description, expected, actual) + t.Fatalf("expected %s to be (%+v) but was (%+v) instead", description, expected, actual) + } +} + +func assertDeepEqual(t *testing.T, description string, expected, actual interface{}) { + t.Helper() + if !reflect.DeepEqual(expected, actual) { + t.Fatalf("expected %s to be (%+v) but was (%+v) instead", description, expected, actual) } } func assertNil(t *testing.T, description string, actual interface{}) { + t.Helper() if actual != nil { - t.Errorf("expected %s to be (nil) but was (%+v) instead", description, actual) + t.Fatalf("expected %s to be (nil) but was (%+v) instead", description, actual) } } func assertNotNil(t *testing.T, description string, actual interface{}) { + t.Helper() if actual == nil { - t.Errorf("expected %s not to be (nil)", description) + t.Fatalf("expected %s not to be (nil)", description) } } diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/compressor_test.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/compressor_test.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/compressor_test.go 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/compressor_test.go 2019-11-02 13:15:23.000000000 +0000 @@ -1,5 +1,3 @@ -// +build all unit - package gocql import ( diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/connectionpool.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/connectionpool.go --- 
golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/connectionpool.go 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/connectionpool.go 2019-11-02 13:15:23.000000000 +0000 @@ -58,7 +58,8 @@ sslOpts.InsecureSkipVerify = !sslOpts.EnableHostVerification - return sslOpts.Config, nil + // return clone to avoid race + return sslOpts.Config.Clone(), nil } type policyConnPool struct { @@ -89,14 +90,16 @@ } return &ConnConfig{ - ProtoVersion: cfg.ProtoVersion, - CQLVersion: cfg.CQLVersion, - Timeout: cfg.Timeout, - ConnectTimeout: cfg.ConnectTimeout, - Compressor: cfg.Compressor, - Authenticator: cfg.Authenticator, - Keepalive: cfg.SocketKeepalive, - tlsConfig: tlsConfig, + ProtoVersion: cfg.ProtoVersion, + CQLVersion: cfg.CQLVersion, + Timeout: cfg.Timeout, + ConnectTimeout: cfg.ConnectTimeout, + Compressor: cfg.Compressor, + Authenticator: cfg.Authenticator, + AuthProvider: cfg.AuthProvider, + Keepalive: cfg.SocketKeepalive, + tlsConfig: tlsConfig, + disableCoalesce: tlsConfig != nil, // write coalescing doesn't work with framing on top of TCP like in TLS. }, nil } @@ -339,14 +342,32 @@ //Close the connection pool func (pool *hostConnPool) Close() { pool.mu.Lock() - defer pool.mu.Unlock() if pool.closed { + pool.mu.Unlock() return } pool.closed = true - pool.drainLocked() + // ensure we dont try to reacquire the lock in handleError + // TODO: improve this as the following can happen + // 1) we have locked pool.mu write lock + // 2) conn.Close calls conn.closeWithError(nil) + // 3) conn.closeWithError calls conn.Close() which returns an error + // 4) conn.closeWithError calls pool.HandleError with the error from conn.Close + // 5) pool.HandleError tries to lock pool.mu + // deadlock + + // empty the pool + conns := pool.conns + pool.conns = nil + + pool.mu.Unlock() + + // close the connections + for _, conn := range conns { + conn.Close() + } } // Fill the connection pool @@ -402,7 +423,9 @@ // this is call with the connection pool mutex held, this call will // then recursively try to lock it again. FIXME - go pool.session.handleNodeDown(pool.host.ConnectAddress(), pool.port) + if pool.session.cfg.ConvictionPolicy.AddFailure(err, pool.host) { + go pool.session.handleNodeDown(pool.host.ConnectAddress(), pool.port) + } return } @@ -479,10 +502,10 @@ func (pool *hostConnPool) connect() (err error) { // TODO: provide a more robust connection retry mechanism, we should also // be able to detect hosts that come up by trying to connect to downed ones. 
- const maxAttempts = 3 // try to connect var conn *Conn - for i := 0; i < maxAttempts; i++ { + reconnectionPolicy := pool.session.cfg.ReconnectionPolicy + for i := 0; i < reconnectionPolicy.GetMaxRetries(); i++ { conn, err = pool.session.connect(pool.host, pool) if err == nil { break @@ -494,6 +517,11 @@ break } } + if gocqlDebug { + Logger.Printf("connection failed %q: %v, reconnecting with %T\n", + pool.host.ConnectAddress(), err, reconnectionPolicy) + } + time.Sleep(reconnectionPolicy.GetInterval(i)) } if err != nil { @@ -551,21 +579,3 @@ } } } - -func (pool *hostConnPool) drainLocked() { - // empty the pool - conns := pool.conns - pool.conns = nil - - // close the connections - for _, conn := range conns { - conn.Close() - } -} - -// removes and closes all connections from the pool -func (pool *hostConnPool) drain() { - pool.mu.Lock() - defer pool.mu.Unlock() - pool.drainLocked() -} diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/conn.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/conn.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/conn.go 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/conn.go 2019-11-02 13:15:23.000000000 +0000 @@ -20,7 +20,6 @@ "time" "github.com/gocql/gocql/internal/lru" - "github.com/gocql/gocql/internal/streams" ) @@ -29,6 +28,8 @@ "org.apache.cassandra.auth.PasswordAuthenticator", "com.instaclustr.cassandra.auth.SharedSecretAuthenticator", "com.datastax.bdp.cassandra.auth.DseAuthenticator", + "io.aiven.cassandra.auth.AivenAuthenticator", + "com.ericsson.bss.cassandra.ecaudit.auth.AuditPasswordAuthenticator", } ) @@ -99,8 +100,11 @@ ConnectTimeout time.Duration Compressor Compressor Authenticator Authenticator + AuthProvider func(h *HostInfo) (Authenticator, error) Keepalive time.Duration - tlsConfig *tls.Config + + tlsConfig *tls.Config + disableCoalesce bool } type ConnErrorHandler interface { @@ -116,32 +120,37 @@ // If not zero, how many timeouts we will allow to occur before the connection is closed // and restarted. This is to prevent a single query timeout from killing a connection // which may be serving more queries just fine. -// Default is 10, should not be changed concurrently with queries. -var TimeoutLimit int64 = 10 +// Default is 0, should not be changed concurrently with queries. +// +// depreciated +var TimeoutLimit int64 = 0 // Conn is a single connection to a Cassandra node. It can be used to execute // queries, but users are usually advised to use a more reliable, higher // level API. type Conn struct { - conn net.Conn - r *bufio.Reader - timeout time.Duration - cfg *ConnConfig + conn net.Conn + r *bufio.Reader + w io.Writer + + timeout time.Duration + cfg *ConnConfig + frameObserver FrameHeaderObserver headerBuf [maxFrameHeaderSize]byte streams *streams.IDGenerator - mu sync.RWMutex + mu sync.Mutex calls map[int]*callReq - errorHandler ConnErrorHandler - compressor Compressor - auth Authenticator - addr string + errorHandler ConnErrorHandler + compressor Compressor + auth Authenticator + addr string + version uint8 currentKeyspace string - - host *HostInfo + host *HostInfo session *Session @@ -151,15 +160,42 @@ timeouts int64 } -// Connect establishes a connection to a Cassandra node. -func Connect(host *HostInfo, cfg *ConnConfig, errorHandler ConnErrorHandler, session *Session) (*Conn, error) { +// connect establishes a connection to a Cassandra node using session's connection config. 
+func (s *Session) connect(host *HostInfo, errorHandler ConnErrorHandler) (*Conn, error) { + return s.dial(host, s.connCfg, errorHandler) +} + +// dial establishes a connection to a Cassandra node and notifies the session's connectObserver. +func (s *Session) dial(host *HostInfo, connConfig *ConnConfig, errorHandler ConnErrorHandler) (*Conn, error) { + var obs ObservedConnect + if s.connectObserver != nil { + obs.Host = host + obs.Start = time.Now() + } + + conn, err := s.dialWithoutObserver(host, connConfig, errorHandler) + + if s.connectObserver != nil { + obs.End = time.Now() + obs.Err = err + s.connectObserver.ObserveConnect(obs) + } + + return conn, err +} + +// dialWithoutObserver establishes connection to a Cassandra node. +// +// dialWithoutObserver does not notify the connection observer, so you most probably want to call dial() instead. +func (s *Session) dialWithoutObserver(host *HostInfo, cfg *ConnConfig, errorHandler ConnErrorHandler) (*Conn, error) { + ip := host.ConnectAddress() + port := host.port + // TODO(zariel): remove these - if host == nil { - panic("host is nil") - } else if len(host.ConnectAddress()) == 0 { - panic(fmt.Sprintf("host missing connect ip address: %v", host)) - } else if host.Port() == 0 { - panic(fmt.Sprintf("host missing port: %v", host)) + if !validIpAddr(ip) { + panic(fmt.Sprintf("host missing connect ip address: %v", ip)) + } else if port == 0 { + panic(fmt.Sprintf("host missing port: %v", port)) } var ( @@ -170,18 +206,16 @@ dialer := &net.Dialer{ Timeout: cfg.ConnectTimeout, } - - // TODO(zariel): handle ipv6 zone - translatedPeer, translatedPort := session.cfg.translateAddressPort(host.ConnectAddress(), host.Port()) - addr := (&net.TCPAddr{IP: translatedPeer, Port: translatedPort}).String() - //addr := (&net.TCPAddr{IP: host.Peer(), Port: host.Port()}).String() + if cfg.Keepalive > 0 { + dialer.KeepAlive = cfg.Keepalive + } if cfg.tlsConfig != nil { // the TLS config is safe to be reused by connections but it must not // be modified after being used. 
- conn, err = tls.DialWithDialer(dialer, "tcp", addr, cfg.tlsConfig) + conn, err = tls.DialWithDialer(dialer, "tcp", host.HostnameAndPort(), cfg.tlsConfig) } else { - conn, err = dialer.Dial("tcp", addr) + conn, err = dialer.Dial("tcp", host.HostnameAndPort()) } if err != nil { @@ -189,24 +223,32 @@ } c := &Conn{ - conn: conn, - r: bufio.NewReader(conn), - cfg: cfg, - calls: make(map[int]*callReq), - timeout: cfg.Timeout, - version: uint8(cfg.ProtoVersion), - addr: conn.RemoteAddr().String(), - errorHandler: errorHandler, - compressor: cfg.Compressor, - auth: cfg.Authenticator, - quit: make(chan struct{}), - session: session, - streams: streams.New(cfg.ProtoVersion), - host: host, + conn: conn, + r: bufio.NewReader(conn), + cfg: cfg, + calls: make(map[int]*callReq), + version: uint8(cfg.ProtoVersion), + addr: conn.RemoteAddr().String(), + errorHandler: errorHandler, + compressor: cfg.Compressor, + quit: make(chan struct{}), + session: s, + streams: streams.New(cfg.ProtoVersion), + host: host, + frameObserver: s.frameObserver, + w: &deadlineWriter{ + w: conn, + timeout: cfg.Timeout, + }, } - if cfg.Keepalive > 0 { - c.setKeepalive(cfg.Keepalive) + if cfg.AuthProvider != nil { + c.auth, err = cfg.AuthProvider(host) + if err != nil { + return nil, err + } + } else { + c.auth = cfg.Authenticator } var ( @@ -214,59 +256,38 @@ cancel func() ) if cfg.ConnectTimeout > 0 { - ctx, cancel = context.WithTimeout(context.Background(), cfg.ConnectTimeout) + ctx, cancel = context.WithTimeout(context.TODO(), cfg.ConnectTimeout) } else { - ctx, cancel = context.WithCancel(context.Background()) + ctx, cancel = context.WithCancel(context.TODO()) } defer cancel() - frameTicker := make(chan struct{}, 1) - startupErr := make(chan error) - go func() { - for range frameTicker { - err := c.recv() - if err != nil { - select { - case startupErr <- err: - case <-ctx.Done(): - } + startup := &startupCoordinator{ + frameTicker: make(chan struct{}), + conn: c, + } - return - } - } - }() + c.timeout = cfg.ConnectTimeout + if err := startup.setupConn(ctx); err != nil { + c.close() + return nil, err + } - go func() { - defer close(frameTicker) - err := c.startup(ctx, frameTicker) - select { - case startupErr <- err: - case <-ctx.Done(): - } - }() + c.timeout = cfg.Timeout - select { - case err := <-startupErr: - if err != nil { - c.Close() - return nil, err - } - case <-ctx.Done(): - c.Close() - return nil, errors.New("gocql: no response to connection startup within timeout") + // dont coalesce startup frames + if s.cfg.WriteCoalesceWaitTime > 0 && !cfg.disableCoalesce { + c.w = newWriteCoalescer(conn, c.timeout, s.cfg.WriteCoalesceWaitTime, c.quit) } go c.serve() + go c.heartBeat() return c, nil } -func (c *Conn) Write(p []byte) (int, error) { - if c.timeout > 0 { - c.conn.SetWriteDeadline(time.Now().Add(c.timeout)) - } - - return c.conn.Write(p) +func (c *Conn) Write(p []byte) (n int, err error) { + return c.w.Write(p) } func (c *Conn) Read(p []byte) (n int, err error) { @@ -292,27 +313,98 @@ return } -func (c *Conn) startup(ctx context.Context, frameTicker chan struct{}) error { - m := map[string]string{ - "CQL_VERSION": c.cfg.CQLVersion, - } +type startupCoordinator struct { + conn *Conn + frameTicker chan struct{} +} - if c.compressor != nil { - m["COMPRESSION"] = c.compressor.Name() +func (s *startupCoordinator) setupConn(ctx context.Context) error { + startupErr := make(chan error) + go func() { + for range s.frameTicker { + err := s.conn.recv() + if err != nil { + select { + case startupErr <- err: + case 
<-ctx.Done(): + } + + return + } + } + }() + + go func() { + defer close(s.frameTicker) + err := s.options(ctx) + select { + case startupErr <- err: + case <-ctx.Done(): + } + }() + + select { + case err := <-startupErr: + if err != nil { + return err + } + case <-ctx.Done(): + return errors.New("gocql: no response to connection startup within timeout") } + return nil +} + +func (s *startupCoordinator) write(ctx context.Context, frame frameWriter) (frame, error) { select { - case frameTicker <- struct{}{}: + case s.frameTicker <- struct{}{}: case <-ctx.Done(): - return ctx.Err() + return nil, ctx.Err() } - framer, err := c.exec(ctx, &writeStartupFrame{opts: m}, nil) + framer, err := s.conn.exec(ctx, frame, nil) + if err != nil { + return nil, err + } + + return framer.parseFrame() +} + +func (s *startupCoordinator) options(ctx context.Context) error { + frame, err := s.write(ctx, &writeOptionsFrame{}) if err != nil { return err } - frame, err := framer.parseFrame() + supported, ok := frame.(*supportedFrame) + if !ok { + return NewErrProtocol("Unknown type of response to startup frame: %T", frame) + } + + return s.startup(ctx, supported.supported) +} + +func (s *startupCoordinator) startup(ctx context.Context, supported map[string][]string) error { + m := map[string]string{ + "CQL_VERSION": s.conn.cfg.CQLVersion, + } + + if s.conn.compressor != nil { + comp := supported["COMPRESSION"] + name := s.conn.compressor.Name() + for _, compressor := range comp { + if compressor == name { + m["COMPRESSION"] = compressor + break + } + } + + if _, ok := m["COMPRESSION"]; !ok { + s.conn.compressor = nil + } + } + + frame, err := s.write(ctx, &writeStartupFrame{opts: m}) if err != nil { return err } @@ -323,37 +415,25 @@ case *readyFrame: return nil case *authenticateFrame: - return c.authenticateHandshake(ctx, v, frameTicker) + return s.authenticateHandshake(ctx, v) default: return NewErrProtocol("Unknown type of response to startup frame: %s", v) } } -func (c *Conn) authenticateHandshake(ctx context.Context, authFrame *authenticateFrame, frameTicker chan struct{}) error { - if c.auth == nil { +func (s *startupCoordinator) authenticateHandshake(ctx context.Context, authFrame *authenticateFrame) error { + if s.conn.auth == nil { return fmt.Errorf("authentication required (using %q)", authFrame.class) } - resp, challenger, err := c.auth.Challenge([]byte(authFrame.class)) + resp, challenger, err := s.conn.auth.Challenge([]byte(authFrame.class)) if err != nil { return err } req := &writeAuthResponseFrame{data: resp} - for { - select { - case frameTicker <- struct{}{}: - case <-ctx.Done(): - return ctx.Err() - } - - framer, err := c.exec(ctx, req, nil) - if err != nil { - return err - } - - frame, err := framer.parseFrame() + frame, err := s.write(ctx, req) if err != nil { return err } @@ -378,8 +458,6 @@ default: return fmt.Errorf("unknown frame response during authentication: %v", v) } - - framerPool.Put(framer) } } @@ -391,7 +469,7 @@ // we should attempt to deliver the error back to the caller if it // exists if err != nil { - c.mu.RLock() + c.mu.Lock() for _, req := range c.calls { // we need to send the error to all waiting queries, put the state // of this conn into not active so that it can not execute any queries. 
@@ -400,18 +478,25 @@ case <-req.timeout: } } - c.mu.RUnlock() + c.mu.Unlock() } // if error was nil then unblock the quit channel close(c.quit) - c.conn.Close() + cerr := c.close() if err != nil { c.errorHandler.HandleError(c, err, true) + } else if cerr != nil { + // TODO(zariel): is it a good idea to do this? + c.errorHandler.HandleError(c, cerr, true) } } +func (c *Conn) close() error { + return c.conn.Close() +} + func (c *Conn) Close() { c.closeWithError(nil) } @@ -420,15 +505,9 @@ // to execute any queries. This method runs as long as the connection is // open and is therefore usually called in a separate goroutine. func (c *Conn) serve() { - var ( - err error - ) - - for { + var err error + for err == nil { err = c.recv() - if err != nil { - break - } } c.closeWithError(err) @@ -453,6 +532,53 @@ return fmt.Sprintf("gocql: received unexpected frame on stream %d: %v", p.frame.Header().stream, p.frame) } +func (c *Conn) heartBeat() { + sleepTime := 1 * time.Second + timer := time.NewTimer(sleepTime) + defer timer.Stop() + + var failures int + + for { + if failures > 5 { + c.closeWithError(fmt.Errorf("gocql: heartbeat failed")) + return + } + + timer.Reset(sleepTime) + + select { + case <-c.quit: + return + case <-timer.C: + } + + framer, err := c.exec(context.Background(), &writeOptionsFrame{}, nil) + if err != nil { + failures++ + continue + } + + resp, err := framer.parseFrame() + if err != nil { + // invalid frame + failures++ + continue + } + + switch resp.(type) { + case *supportedFrame: + // Everything ok + sleepTime = 5 * time.Second + failures = 0 + case error: + // TODO: should we do something here? + default: + panic(fmt.Sprintf("gocql: unknown frame in response to options: %T", resp)) + } + } +} + func (c *Conn) recv() error { // not safe for concurrent reads @@ -462,14 +588,29 @@ c.conn.SetReadDeadline(time.Time{}) } + headStartTime := time.Now() // were just reading headers over and over and copy bodies head, err := readHeader(c.r, c.headerBuf[:]) + headEndTime := time.Now() if err != nil { return err } + if c.frameObserver != nil { + c.frameObserver.ObserveFrameHeader(context.Background(), ObservedFrameHeader{ + Version: protoVersion(head.version), + Flags: head.flags, + Stream: int16(head.stream), + Opcode: frameOp(head.op), + Length: int32(head.length), + Start: headStartTime, + End: headEndTime, + Host: c.host, + }) + } + if head.stream > c.streams.NumStreams { - return fmt.Errorf("gocql: frame header stream is beyond call exepected bounds: %d", head.stream) + return fmt.Errorf("gocql: frame header stream is beyond call expected bounds: %d", head.stream) } else if head.stream == -1 { // TODO: handle cassandra event frames, we shouldnt get any currently framer := newFramer(c, c, c.compressor, c.version) @@ -485,7 +626,6 @@ if err := framer.readFrame(&head); err != nil { return err } - defer framerPool.Put(framer) frame, err := framer.parseFrame() if err != nil { @@ -497,12 +637,15 @@ } } - c.mu.RLock() + c.mu.Lock() call, ok := c.calls[head.stream] - c.mu.RUnlock() + delete(c.calls, head.stream) + c.mu.Unlock() if call == nil || call.framer == nil || !ok { Logger.Printf("gocql: received response for stream which has no handler: header=%v\n", head) return c.discardFrame(head) + } else if head.stream != call.streamID { + panic(fmt.Sprintf("call has incorrect streamID: got %d expected %d", call.streamID, head.stream)) } err = call.framer.readFrame(&head) @@ -519,30 +662,19 @@ select { case call.resp <- err: case <-call.timeout: - c.releaseStream(head.stream) + 
c.releaseStream(call) case <-c.quit: } return nil } -func (c *Conn) releaseStream(stream int) { - c.mu.Lock() - call := c.calls[stream] - if call != nil && stream != call.streamID { - panic(fmt.Sprintf("attempt to release streamID with ivalid stream: %d -> %+v\n", stream, call)) - } else if call == nil { - panic(fmt.Sprintf("releasing a stream not in use: %d", stream)) - } - delete(c.calls, stream) - c.mu.Unlock() - +func (c *Conn) releaseStream(call *callReq) { if call.timer != nil { call.timer.Stop() } - streamPool.Put(call) - c.streams.Clear(stream) + c.streams.Clear(call.streamID) } func (c *Conn) handleTimeout() { @@ -551,16 +683,6 @@ } } -var ( - streamPool = sync.Pool{ - New: func() interface{} { - return &callReq{ - resp: make(chan error), - } - }, - } -) - type callReq struct { // could use a waitgroup but this allows us to do timeouts on the read/send resp chan error @@ -571,6 +693,142 @@ timer *time.Timer } +type deadlineWriter struct { + w interface { + SetWriteDeadline(time.Time) error + io.Writer + } + timeout time.Duration +} + +func (c *deadlineWriter) Write(p []byte) (int, error) { + if c.timeout > 0 { + c.w.SetWriteDeadline(time.Now().Add(c.timeout)) + } + return c.w.Write(p) +} + +func newWriteCoalescer(conn net.Conn, timeout time.Duration, d time.Duration, quit <-chan struct{}) *writeCoalescer { + wc := &writeCoalescer{ + writeCh: make(chan struct{}), // TODO: could this be sync? + cond: sync.NewCond(&sync.Mutex{}), + c: conn, + quit: quit, + timeout: timeout, + } + go wc.writeFlusher(d) + return wc +} + +type writeCoalescer struct { + c net.Conn + + quit <-chan struct{} + writeCh chan struct{} + running bool + + // cond waits for the buffer to be flushed + cond *sync.Cond + buffers net.Buffers + timeout time.Duration + + // result of the write + err error +} + +func (w *writeCoalescer) flushLocked() { + w.running = false + if len(w.buffers) == 0 { + return + } + + if w.timeout > 0 { + w.c.SetWriteDeadline(time.Now().Add(w.timeout)) + } + + // Given we are going to do a fanout n is useless and according to + // the docs WriteTo should return 0 and err or bytes written and + // no error. + _, w.err = w.buffers.WriteTo(w.c) + if w.err != nil { + w.buffers = nil + } + w.cond.Broadcast() +} + +func (w *writeCoalescer) flush() { + w.cond.L.Lock() + w.flushLocked() + w.cond.L.Unlock() +} + +func (w *writeCoalescer) stop() { + w.cond.L.Lock() + defer w.cond.L.Unlock() + + w.flushLocked() + // nil the channel out sends block forever on it + // instead of closing which causes a send on closed channel + // panic. + w.writeCh = nil +} + +func (w *writeCoalescer) Write(p []byte) (int, error) { + w.cond.L.Lock() + + if !w.running { + select { + case w.writeCh <- struct{}{}: + w.running = true + case <-w.quit: + w.cond.L.Unlock() + return 0, io.EOF // TODO: better error here? 
+ } + } + + w.buffers = append(w.buffers, p) + for len(w.buffers) != 0 { + w.cond.Wait() + } + + err := w.err + w.cond.L.Unlock() + + if err != nil { + return 0, err + } + return len(p), nil +} + +func (w *writeCoalescer) writeFlusher(interval time.Duration) { + timer := time.NewTimer(interval) + defer timer.Stop() + defer w.stop() + + if !timer.Stop() { + <-timer.C + } + + for { + // wait for a write to start the flush loop + select { + case <-w.writeCh: + case <-w.quit: + return + } + + timer.Reset(interval) + + select { + case <-w.quit: + return + case <-timer.C: + } + + w.flush() + } +} + func (c *Conn) exec(ctx context.Context, req frameWriter, tracer Tracer) (*framer, error) { // TODO: move tracer onto conn stream, ok := c.streams.GetStream() @@ -581,21 +839,24 @@ // resp is basically a waiting semaphore protecting the framer framer := newFramer(c, c, c.compressor, c.version) - c.mu.Lock() - call := c.calls[stream] - if call != nil { - c.mu.Unlock() - return nil, fmt.Errorf("attempting to use stream already in use: %d -> %d", stream, call.streamID) - } else { - call = streamPool.Get().(*callReq) + call := &callReq{ + framer: framer, + timeout: make(chan struct{}), + streamID: stream, + resp: make(chan error), } - c.calls[stream] = call - call.framer = framer - call.timeout = make(chan struct{}) - call.streamID = stream + c.mu.Lock() + existingCall := c.calls[stream] + if existingCall == nil { + c.calls[stream] = call + } c.mu.Unlock() + if existingCall != nil { + return nil, fmt.Errorf("attempting to use stream already in use: %d -> %d", stream, existingCall.streamID) + } + if tracer != nil { framer.trace() } @@ -647,7 +908,7 @@ // this is because the request is still outstanding and we have // been handed another error from another stream which caused the // connection to close. - c.releaseStream(stream) + c.releaseStream(call) } return nil, err } @@ -668,7 +929,7 @@ // // Ensure that the stream is not released if there are potentially outstanding // requests on the stream to prevent nil pointer dereferences in recv(). - defer c.releaseStream(stream) + defer c.releaseStream(call) if v := framer.header.version.version(); v != c.version { return nil, NewErrProtocol("unexpected protocol version in response: got %d expected %d", v, c.version) @@ -707,6 +968,9 @@ prep := &writePrepareFrame{ statement: stmt, } + if c.version > protoVersion4 { + prep.keyspace = c.currentKeyspace + } framer, err := c.exec(ctx, prep, tracer) if err != nil { @@ -720,6 +984,7 @@ if err != nil { flight.err = err flight.wg.Done() + c.session.stmtsLRU.remove(stmtCacheKey) return nil, err } @@ -751,8 +1016,6 @@ c.session.stmtsLRU.remove(stmtCacheKey) } - framerPool.Put(framer) - return flight.preparedStatment, flight.err } @@ -776,7 +1039,7 @@ return nil } -func (c *Conn) executeQuery(qry *Query) *Iter { +func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { params := queryParams{ consistency: qry.cons, } @@ -792,6 +1055,9 @@ if qry.pageSize > 0 { params.pageSize = qry.pageSize } + if c.version > protoVersion4 { + params.keyspace = c.currentKeyspace + } var ( frame frameWriter @@ -801,7 +1067,7 @@ if qry.shouldPrepare() { // Prepare all DML queries. Other queries can not be prepared. 
var err error - info, err = c.prepareStatement(qry.context, qry.stmt, qry.trace) + info, err = c.prepareStatement(ctx, qry.stmt, qry.trace) if err != nil { return &Iter{err: err} } @@ -840,17 +1106,19 @@ params.skipMeta = !(c.session.cfg.DisableSkipMetadata || qry.disableSkipMetadata) frame = &writeExecuteFrame{ - preparedID: info.id, - params: params, + preparedID: info.id, + params: params, + customPayload: qry.customPayload, } } else { frame = &writeQueryFrame{ - statement: qry.stmt, - params: params, + statement: qry.stmt, + params: params, + customPayload: qry.customPayload, } } - framer, err := c.exec(qry.context, frame, qry.trace) + framer, err := c.exec(ctx, frame, qry.trace) if err != nil { return &Iter{err: err} } @@ -877,7 +1145,7 @@ if params.skipMeta { if info != nil { iter.meta = info.response - iter.meta.pagingState = x.meta.pagingState + iter.meta.pagingState = copyBytes(x.meta.pagingState) } else { return &Iter{framer: framer, err: errors.New("gocql: did not receive metadata but prepared info is nil")} } @@ -885,9 +1153,9 @@ iter.meta = x.meta } - if len(x.meta.pagingState) > 0 && !qry.disableAutoPage { + if x.meta.morePages() && !qry.disableAutoPage { iter.next = &nextIter{ - qry: *qry, + qry: qry, pos: int((1 - qry.prefetch) * float64(x.numRows)), } @@ -900,9 +1168,9 @@ return iter case *resultKeyspaceFrame: return &Iter{framer: framer} - case *schemaChangeKeyspace, *schemaChangeTable, *schemaChangeFunction: + case *schemaChangeKeyspace, *schemaChangeTable, *schemaChangeFunction, *schemaChangeAggregate, *schemaChangeType: iter := &Iter{framer: framer} - if err := c.awaitSchemaAgreement(); err != nil { + if err := c.awaitSchemaAgreement(ctx); err != nil { // TODO: should have this behind a flag Logger.Println(err) } @@ -913,7 +1181,7 @@ case *RequestErrUnprepared: stmtCacheKey := c.session.stmtsLRU.keyFor(c.addr, c.currentKeyspace, qry.stmt) if c.session.stmtsLRU.remove(stmtCacheKey) { - return c.executeQuery(qry) + return c.executeQuery(ctx, qry) } return &Iter{err: x, framer: framer} @@ -973,7 +1241,7 @@ return nil } -func (c *Conn) executeBatch(batch *Batch) *Iter { +func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter { if c.version == protoVersion1 { return &Iter{err: ErrUnsupported} } @@ -986,6 +1254,7 @@ serialConsistency: batch.serialCons, defaultTimestamp: batch.defaultTimestamp, defaultTimestampValue: batch.defaultTimestampValue, + customPayload: batch.CustomPayload, } stmts := make(map[string]string, len(batch.Entries)) @@ -993,8 +1262,9 @@ for i := 0; i < n; i++ { entry := &batch.Entries[i] b := &req.statements[i] + if len(entry.Args) > 0 || entry.binding != nil { - info, err := c.prepareStatement(batch.context, entry.Stmt, nil) + info, err := c.prepareStatement(batch.Context(), entry.Stmt, nil) if err != nil { return &Iter{err: err} } @@ -1037,7 +1307,7 @@ } // TODO: should batch support tracing? 
- framer, err := c.exec(batch.context, req, nil) + framer, err := c.exec(batch.Context(), req, nil) if err != nil { return &Iter{err: err} } @@ -1049,7 +1319,6 @@ switch x := resp.(type) { case *resultVoidFrame: - framerPool.Put(framer) return &Iter{} case *RequestErrUnprepared: stmt, found := stmts[string(x.StatementId)] @@ -1058,10 +1327,8 @@ c.session.stmtsLRU.remove(key) } - framerPool.Put(framer) - if found { - return c.executeBatch(batch) + return c.executeBatch(ctx, batch) } else { return &Iter{err: x, framer: framer} } @@ -1080,54 +1347,50 @@ } } -func (c *Conn) setKeepalive(d time.Duration) error { - if tc, ok := c.conn.(*net.TCPConn); ok { - err := tc.SetKeepAlivePeriod(d) - if err != nil { - return err - } - - return tc.SetKeepAlive(true) - } - - return nil -} - -func (c *Conn) query(statement string, values ...interface{}) (iter *Iter) { +func (c *Conn) query(ctx context.Context, statement string, values ...interface{}) (iter *Iter) { q := c.session.Query(statement, values...).Consistency(One) - return c.executeQuery(q) + q.trace = nil + return c.executeQuery(ctx, q) } -func (c *Conn) awaitSchemaAgreement() (err error) { +func (c *Conn) awaitSchemaAgreement(ctx context.Context) (err error) { const ( - peerSchemas = "SELECT schema_version FROM system.peers" + peerSchemas = "SELECT * FROM system.peers" localSchemas = "SELECT schema_version FROM system.local WHERE key='local'" ) var versions map[string]struct{} + var schemaVersion string endDeadline := time.Now().Add(c.session.cfg.MaxWaitSchemaAgreement) for time.Now().Before(endDeadline) { - iter := c.query(peerSchemas) + iter := c.query(ctx, peerSchemas) versions = make(map[string]struct{}) - var schemaVersion string - for iter.Scan(&schemaVersion) { - if schemaVersion == "" { - Logger.Println("skipping peer entry with empty schema_version") + rows, err := iter.SliceMap() + if err != nil { + goto cont + } + + for _, row := range rows { + host, err := c.session.hostInfoFromMap(row, &HostInfo{connectAddress: c.host.ConnectAddress(), port: c.session.cfg.Port}) + if err != nil { + goto cont + } + if !isValidPeer(host) || host.schemaVersion == "" { + Logger.Printf("invalid peer or peer with empty schema_version: peer=%q", host) continue } - versions[schemaVersion] = struct{}{} - schemaVersion = "" + versions[host.schemaVersion] = struct{}{} } if err = iter.Close(); err != nil { goto cont } - iter = c.query(localSchemas) + iter = c.query(ctx, localSchemas) for iter.Scan(&schemaVersion) { versions[schemaVersion] = struct{}{} schemaVersion = "" @@ -1142,11 +1405,15 @@ } cont: - time.Sleep(200 * time.Millisecond) + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(200 * time.Millisecond): + } } if err != nil { - return + return err } schemas := make([]string, 0, len(versions)) @@ -1158,6 +1425,23 @@ return fmt.Errorf("gocql: cluster schema versions not consistent: %+v", schemas) } +func (c *Conn) localHostInfo(ctx context.Context) (*HostInfo, error) { + row, err := c.query(ctx, "SELECT * FROM system.local WHERE key='local'").rowMap() + if err != nil { + return nil, err + } + + port := c.conn.RemoteAddr().(*net.TCPAddr).Port + + // TODO(zariel): avoid doing this here + host, err := c.session.hostInfoFromMap(row, &HostInfo{connectAddress: c.host.connectAddress, port: port}) + if err != nil { + return nil, err + } + + return c.session.ring.addOrUpdate(host), nil +} + var ( ErrQueryArgLength = errors.New("gocql: query argument length mismatch") ErrTimeoutNoResponse = errors.New("gocql: no response received from cassandra 
within timeout period") diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/conn_test.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/conn_test.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/conn_test.go 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/conn_test.go 2019-11-02 13:15:23.000000000 +0000 @@ -6,18 +6,24 @@ package gocql import ( + "bufio" + "bytes" "context" "crypto/tls" "crypto/x509" "fmt" "io" "io/ioutil" + "math/rand" "net" + "os" "strings" "sync" "sync/atomic" "testing" "time" + + "github.com/gocql/gocql/internal/streams" ) const ( @@ -29,6 +35,7 @@ approve("org.apache.cassandra.auth.PasswordAuthenticator"): true, approve("com.instaclustr.cassandra.auth.SharedSecretAuthenticator"): true, approve("com.datastax.bdp.cassandra.auth.DseAuthenticator"): true, + approve("io.aiven.cassandra.auth.AivenAuthenticator"): true, approve("com.apache.cassandra.auth.FakeAuthenticator"): false, } for k, v := range tests { @@ -40,8 +47,8 @@ func TestJoinHostPort(t *testing.T) { tests := map[string]string{ - "127.0.0.1:0": JoinHostPort("127.0.0.1", 0), - "127.0.0.1:1": JoinHostPort("127.0.0.1:1", 9142), + "127.0.0.1:0": JoinHostPort("127.0.0.1", 0), + "127.0.0.1:1": JoinHostPort("127.0.0.1:1", 9142), "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:0": JoinHostPort("2001:0db8:85a3:0000:0000:8a2e:0370:7334", 0), "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:1": JoinHostPort("[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:1", 9142), } @@ -52,8 +59,8 @@ } } -func testCluster(addr string, proto protoVersion) *ClusterConfig { - cluster := NewCluster(addr) +func testCluster(proto protoVersion, addresses ...string) *ClusterConfig { + cluster := NewCluster(addresses...) cluster.ProtoVersion = int(proto) cluster.disableControlConn = true return cluster @@ -63,7 +70,7 @@ srv := NewTestServer(t, defaultProto, context.Background()) defer srv.Stop() - cluster := testCluster(srv.Address, defaultProto) + cluster := testCluster(defaultProto, srv.Address) db, err := cluster.CreateSession() if err != nil { t.Fatalf("0x%x: NewCluster: %v", defaultProto, err) @@ -103,7 +110,7 @@ } func createTestSslCluster(addr string, proto protoVersion, useClientCert bool) *ClusterConfig { - cluster := testCluster(addr, proto) + cluster := testCluster(proto, addr) sslOpts := &SslOptions{ CaPath: "testdata/pki/ca.crt", EnableHostVerification: false, @@ -124,7 +131,7 @@ srv := NewTestServer(t, defaultProto, context.Background()) defer srv.Stop() - session, err := newTestSession(srv.Address, defaultProto) + session, err := newTestSession(defaultProto, srv.Address) if err != nil { t.Fatalf("0x%x: NewCluster: %v", defaultProto, err) } @@ -136,8 +143,8 @@ } } -func newTestSession(addr string, proto protoVersion) (*Session, error) { - return testCluster(addr, proto).CreateSession() +func newTestSession(proto protoVersion, addresses ...string) (*Session, error) { + return testCluster(proto, addresses...).CreateSession() } func TestDNSLookupConnected(t *testing.T) { @@ -147,6 +154,10 @@ Logger = &defaultLogger{} }() + // Override the defaul DNS resolver and restore at the end + failDNS = true + defer func() { failDNS = false }() + srv := NewTestServer(t, defaultProto, context.Background()) defer srv.Stop() @@ -173,8 +184,9 @@ Logger = &defaultLogger{} }() - srv := NewTestServer(t, defaultProto, context.Background()) - defer srv.Stop() + // Override the defaul DNS resolver and restore at the end + failDNS = true + defer func() { failDNS = false }() cluster := 
NewCluster("cassandra1.invalid", "cassandra2.invalid") cluster.ProtoVersion = int(defaultProto) @@ -246,7 +258,7 @@ srv := NewTestServer(t, defaultProto, ctx) defer srv.Stop() - db, err := newTestSession(srv.Address, defaultProto) + db, err := newTestSession(defaultProto, srv.Address) if err != nil { t.Fatalf("NewCluster: %v", err) } @@ -273,6 +285,57 @@ wg.Wait() } +func TestCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + srv := NewTestServer(t, defaultProto, ctx) + defer srv.Stop() + + cluster := testCluster(defaultProto, srv.Address) + cluster.Timeout = 1 * time.Second + db, err := cluster.CreateSession() + if err != nil { + t.Fatalf("NewCluster: %v", err) + } + defer db.Close() + + qry := db.Query("timeout").WithContext(ctx) + + // Make sure we finish the query without leftovers + var wg sync.WaitGroup + wg.Add(1) + + go func() { + if err := qry.Exec(); err != context.Canceled { + t.Fatalf("expected to get context cancel error: '%v', got '%v'", context.Canceled, err) + } + wg.Done() + }() + + // The query will timeout after about 1 seconds, so cancel it after a short pause + time.AfterFunc(20*time.Millisecond, cancel) + wg.Wait() +} + +type testQueryObserver struct { + metrics map[string]*hostMetrics + verbose bool +} + +func (o *testQueryObserver) ObserveQuery(ctx context.Context, q ObservedQuery) { + host := q.Host.ConnectAddress().String() + o.metrics[host] = q.Metrics + if o.verbose { + Logger.Printf("Observed query %q. Returned %v rows, took %v on host %q with %v attempts and total latency %v. Error: %q\n", + q.Statement, q.Rows, q.End.Sub(q.Start), host, q.Metrics.Attempts, q.Metrics.TotalLatency, q.Err) + } +} + +func (o *testQueryObserver) GetMetrics(host *HostInfo) *hostMetrics { + return o.metrics[host.ConnectAddress().String()] +} + // TestQueryRetry will test to make sure that gocql will execute // the exact amount of retry queries designated by the user. 
func TestQueryRetry(t *testing.T) { @@ -282,7 +345,7 @@ srv := NewTestServer(t, defaultProto, ctx) defer srv.Stop() - db, err := newTestSession(srv.Address, defaultProto) + db, err := newTestSession(defaultProto, srv.Address) if err != nil { t.Fatalf("NewCluster: %v", err) } @@ -316,13 +379,152 @@ } } +func TestQueryMultinodeWithMetrics(t *testing.T) { + log := &testLogger{} + Logger = log + defer func() { + Logger = &defaultLogger{} + os.Stdout.WriteString(log.String()) + }() + + // Build a 3 node cluster to test host metric mapping + var nodes []*TestServer + var addresses = []string{ + "127.0.0.1", + "127.0.0.2", + "127.0.0.3", + } + // Can do with 1 context for all servers + ctx := context.Background() + for _, ip := range addresses { + srv := NewTestServerWithAddress(ip+":0", t, defaultProto, ctx) + defer srv.Stop() + nodes = append(nodes, srv) + } + + db, err := newTestSession(defaultProto, nodes[0].Address, nodes[1].Address, nodes[2].Address) + if err != nil { + t.Fatalf("NewCluster: %v", err) + } + defer db.Close() + + // 1 retry per host + rt := &SimpleRetryPolicy{NumRetries: 3} + observer := &testQueryObserver{metrics: make(map[string]*hostMetrics), verbose: false} + qry := db.Query("kill").RetryPolicy(rt).Observer(observer) + if err := qry.Exec(); err == nil { + t.Fatalf("expected error") + } + + for i, ip := range addresses { + host := &HostInfo{connectAddress: net.ParseIP(ip)} + queryMetric := qry.metrics.hostMetrics(host) + observedMetrics := observer.GetMetrics(host) + + requests := int(atomic.LoadInt64(&nodes[i].nKillReq)) + hostAttempts := queryMetric.Attempts + if requests != hostAttempts { + t.Fatalf("expected requests %v to match query attempts %v", requests, hostAttempts) + } + + if hostAttempts != observedMetrics.Attempts { + t.Fatalf("expected observed attempts %v to match query attempts %v on host %v", observedMetrics.Attempts, hostAttempts, ip) + } + + hostLatency := queryMetric.TotalLatency + observedLatency := observedMetrics.TotalLatency + if hostLatency != observedLatency { + t.Fatalf("expected observed latency %v to match query latency %v on host %v", observedLatency, hostLatency, ip) + } + } + // the query will only be attempted once, but is being retried + attempts := qry.Attempts() + if attempts != rt.NumRetries { + t.Fatalf("failed to retry the query %v time(s). 
Query executed %v times", rt.NumRetries, attempts) + } + +} + +type testRetryPolicy struct { + NumRetries int +} + +func (t *testRetryPolicy) Attempt(qry RetryableQuery) bool { + return qry.Attempts() <= t.NumRetries +} +func (t *testRetryPolicy) GetRetryType(err error) RetryType { + return Retry +} + +func TestSpeculativeExecution(t *testing.T) { + log := &testLogger{} + Logger = log + defer func() { + Logger = &defaultLogger{} + os.Stdout.WriteString(log.String()) + }() + + // Build a 3 node cluster + var nodes []*TestServer + var addresses = []string{ + "127.0.0.1", + "127.0.0.2", + "127.0.0.3", + } + // Can do with 1 context for all servers + ctx := context.Background() + for _, ip := range addresses { + srv := NewTestServerWithAddress(ip+":0", t, defaultProto, ctx) + defer srv.Stop() + nodes = append(nodes, srv) + } + + db, err := newTestSession(defaultProto, nodes[0].Address, nodes[1].Address, nodes[2].Address) + if err != nil { + t.Fatalf("NewCluster: %v", err) + } + defer db.Close() + + // Create a test retry policy, 6 retries will cover 2 executions + rt := &testRetryPolicy{NumRetries: 8} + // test Speculative policy with 1 additional execution + sp := &SimpleSpeculativeExecution{NumAttempts: 1, TimeoutDelay: 200 * time.Millisecond} + + // Build the query + qry := db.Query("speculative").RetryPolicy(rt).SetSpeculativeExecutionPolicy(sp).Idempotent(true) + + // Execute the query and close, check that it doesn't error out + if err := qry.Exec(); err != nil { + t.Errorf("The query failed with '%v'!\n", err) + } + requests1 := atomic.LoadInt64(&nodes[0].nKillReq) + requests2 := atomic.LoadInt64(&nodes[1].nKillReq) + requests3 := atomic.LoadInt64(&nodes[2].nKillReq) + + // Spec Attempts == 1, so expecting to see only 1 regular + 1 speculative = 2 nodes attempted + if requests1 != 0 && requests2 != 0 && requests3 != 0 { + t.Error("error: all 3 nodes were attempted, should have been only 2") + } + + // Only the 4th request will generate results, so + if requests1 != 4 && requests2 != 4 && requests3 != 4 { + t.Error("error: none of 3 nodes was attempted 4 times!") + } + + // "speculative" query will succeed on one arbitrary node after 4 attempts, so + // expecting to see 4 (on successful node) + not more than 2 (as cancelled on another node) == 6 + if requests1+requests2+requests3 > 6 { + t.Errorf("error: expected to see 6 attempts, got %v\n", requests1+requests2+requests3) + } +} + func TestStreams_Protocol1(t *testing.T) { srv := NewTestServer(t, protoVersion1, context.Background()) defer srv.Stop() // TODO: these are more like session tests and should instead operate // on a single Conn - cluster := testCluster(srv.Address, protoVersion1) + cluster := testCluster(protoVersion1, srv.Address) cluster.NumConns = 1 cluster.ProtoVersion = 1 @@ -354,7 +556,7 @@ // TODO: these are more like session tests and should instead operate // on a single Conn - cluster := testCluster(srv.Address, protoVersion3) + cluster := testCluster(protoVersion3, srv.Address) cluster.NumConns = 1 cluster.ProtoVersion = 3 @@ -430,7 +632,7 @@ srv := NewTestServer(t, defaultProto, context.Background()) defer srv.Stop() - cluster := testCluster(srv.Address, defaultProto) + cluster := testCluster(defaultProto, srv.Address) // Set the timeout arbitrarily low so that the query hits the timeout in a // timely manner. 
cluster.Timeout = 1 * time.Millisecond @@ -457,7 +659,7 @@ if err != ErrTimeoutNoResponse { t.Fatalf("expected to get %v for timeout got %v", ErrTimeoutNoResponse, err) } - case <-time.After(10*time.Millisecond + db.cfg.Timeout): + case <-time.After(40*time.Millisecond + db.cfg.Timeout): // ensure that the query goroutines have been scheduled t.Fatalf("query did not timeout after %v", db.cfg.Timeout) } @@ -467,7 +669,7 @@ srv := NewTestServer(b, 3, context.Background()) defer srv.Stop() - cluster := testCluster(srv.Address, 3) + cluster := testCluster(3, srv.Address) // Set the timeout arbitrarily low so that the query hits the timeout in a // timely manner. cluster.Timeout = 500 * time.Millisecond @@ -498,7 +700,7 @@ srv := NewTestServer(t, defaultProto, context.Background()) defer srv.Stop() - cluster := testCluster(srv.Address, defaultProto) + cluster := testCluster(defaultProto, srv.Address) // Set the timeout arbitrarily low so that the query hits the timeout in a // timely manner. cluster.Timeout = 1 * time.Millisecond @@ -522,7 +724,7 @@ srv := NewTestServer(t, defaultProto, context.Background()) defer srv.Stop() - cluster := testCluster(srv.Address, defaultProto) + cluster := testCluster(defaultProto, srv.Address) // Set the timeout arbitrarily low so that the query hits the timeout in a // timely manner. cluster.Timeout = 1000 * time.Millisecond @@ -557,96 +759,262 @@ // TODO: replace this with type check const expErr = "gocql: received unexpected frame on stream 0" + var buf bytes.Buffer + f := newFramer(nil, &buf, nil, protoVersion4) + f.writeHeader(0, opResult, 0) + f.writeInt(resultKindVoid) + f.wbuf[0] |= 0x80 + if err := f.finishWrite(); err != nil { + t.Fatal(err) + } + + conn := &Conn{ + r: bufio.NewReader(&buf), + streams: streams.New(protoVersion4), + } + + err := conn.recv() + if err == nil { + t.Fatal("expected to get an error on stream 0") + } else if !strings.HasPrefix(err.Error(), expErr) { + t.Fatalf("expected to get error prefix %q got %q", expErr, err.Error()) + } +} + +func TestContext_Timeout(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() srv := NewTestServer(t, defaultProto, ctx) defer srv.Stop() - errorHandler := connErrorHandlerFn(func(conn *Conn, err error, closed bool) { - if !srv.isClosed() && !strings.HasPrefix(err.Error(), expErr) { - select { - case <-ctx.Done(): - return - default: - t.Errorf("expected to get error prefix %q got %q", expErr, err.Error()) - } - } - }) - - conn, err := Connect(srv.host(), &ConnConfig{ProtoVersion: int(srv.protocol)}, errorHandler, createTestSession()) + cluster := testCluster(defaultProto, srv.Address) + cluster.Timeout = 5 * time.Second + db, err := cluster.CreateSession() if err != nil { t.Fatal(err) } + defer db.Close() - writer := frameWriterFunc(func(f *framer, streamID int) error { - f.writeQueryFrame(0, "void", &queryParams{}) - return f.finishWrite() - }) + ctx, cancel = context.WithCancel(ctx) + cancel() - // need to write out an invalid frame, which we need a connection to do - framer, err := conn.exec(ctx, writer, nil) - if err == nil { - t.Fatal("expected to get an error on stream 0") - } else if !strings.HasPrefix(err.Error(), expErr) { - t.Fatalf("expected to get error prefix %q got %q", expErr, err.Error()) - } else if framer != nil { - frame, err := framer.parseFrame() + err = db.Query("timeout").WithContext(ctx).Exec() + if err != context.Canceled { + t.Fatalf("expected to get context cancel error: %v got %v", context.Canceled, err) + } +} + +// tcpConnPair 
returns a matching set of a TCP client side and server side connection. +func tcpConnPair() (s, c net.Conn, err error) { + l, err := net.Listen("tcp", "localhost:0") + if err != nil { + // maybe ipv6 works, if ipv4 fails? + l, err = net.Listen("tcp6", "[::1]:0") if err != nil { - t.Fatal(err) + return nil, nil, err + } + } + defer l.Close() // we only try to accept one connection, so will stop listening. + + addr := l.Addr() + done := make(chan struct{}) + var errDial error + go func(done chan<- struct{}) { + c, errDial = net.Dial(addr.Network(), addr.String()) + close(done) + }(done) + + s, err = l.Accept() + <-done + + if err == nil { + err = errDial + } + + if err != nil { + if s != nil { + s.Close() + } + if c != nil { + c.Close() } - t.Fatalf("got frame %v", frame) } + + return s, c, err } -func TestConnClosedBlocked(t *testing.T) { - // issue 664 - const proto = 3 +func TestWriteCoalescing(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + server, client, err := tcpConnPair() + if err != nil { + t.Fatal(err) + } - srv := NewTestServer(t, proto, context.Background()) - defer srv.Stop() - errorHandler := connErrorHandlerFn(func(conn *Conn, err error, closed bool) { - t.Log(err) - }) + done := make(chan struct{}, 1) + var ( + buf bytes.Buffer + bufMutex sync.Mutex + ) + go func() { + defer close(done) + defer server.Close() + var err error + b := make([]byte, 256) + var n int + for { + if n, err = server.Read(b); err != nil { + break + } + bufMutex.Lock() + buf.Write(b[:n]) + bufMutex.Unlock() + } + if err != io.EOF { + t.Errorf("unexpected read error: %v", err) + } + }() + w := &writeCoalescer{ + c: client, + writeCh: make(chan struct{}), + cond: sync.NewCond(&sync.Mutex{}), + quit: ctx.Done(), + running: true, + } + + go func() { + if _, err := w.Write([]byte("one")); err != nil { + t.Error(err) + } + }() + + go func() { + if _, err := w.Write([]byte("two")); err != nil { + t.Error(err) + } + }() + + bufMutex.Lock() + if buf.Len() != 0 { + t.Fatalf("expected buffer to be empty have: %v", buf.String()) + } + bufMutex.Unlock() + + for true { + w.cond.L.Lock() + if len(w.buffers) == 2 { + w.cond.L.Unlock() + break + } + w.cond.L.Unlock() + } + + w.flush() + client.Close() + <-done - conn, err := Connect(srv.host(), &ConnConfig{ProtoVersion: int(srv.protocol)}, errorHandler, createTestSession()) + if got := buf.String(); got != "onetwo" && got != "twoone" { + t.Fatalf("expected to get %q got %q", "onetwo or twoone", got) + } +} + +func TestWriteCoalescing_WriteAfterClose(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var buf bytes.Buffer + defer cancel() + server, client, err := tcpConnPair() if err != nil { t.Fatal(err) } - if err := conn.conn.Close(); err != nil { + done := make(chan struct{}, 1) + go func() { + io.Copy(&buf, server) + server.Close() + close(done) + }() + w := newWriteCoalescer(client, 0, 5*time.Millisecond, ctx.Done()) + + // ensure 1 write works + if _, err := w.Write([]byte("one")); err != nil { t.Fatal(err) } - // This will block indefintaly if #664 is not fixed - err = conn.executeQuery(&Query{stmt: "void"}).Close() - if !strings.HasSuffix(err.Error(), "use of closed network connection") { - t.Fatalf("expected to get use of closed networking connection error got: %v\n", err) + client.Close() + <-done + if v := buf.String(); v != "one" { + t.Fatalf("expected buffer to be %q got %q", "one", v) + } + + // now close and do a write, we should error + cancel() + client.Close() // 
close client conn too, since server won't see the answer anyway. + + if _, err := w.Write([]byte("two")); err == nil { + t.Fatal("expected to get error for write after closing") + } else if err != io.EOF { + t.Fatalf("expected to get EOF got %v", err) } } -func TestContext_Timeout(t *testing.T) { +type recordingFrameHeaderObserver struct { + t *testing.T + mu sync.Mutex + frames []ObservedFrameHeader +} + +func (r *recordingFrameHeaderObserver) ObserveFrameHeader(ctx context.Context, frm ObservedFrameHeader) { + r.mu.Lock() + r.frames = append(r.frames, frm) + r.mu.Unlock() +} + +func (r *recordingFrameHeaderObserver) getFrames() []ObservedFrameHeader { + r.mu.Lock() + defer r.mu.Unlock() + return r.frames +} + +func TestFrameHeaderObserver(t *testing.T) { srv := NewTestServer(t, defaultProto, context.Background()) defer srv.Stop() - cluster := testCluster(srv.Address, defaultProto) - cluster.Timeout = 5 * time.Second + cluster := testCluster(defaultProto, srv.Address) + cluster.NumConns = 1 + observer := &recordingFrameHeaderObserver{t: t} + cluster.FrameHeaderObserver = observer + db, err := cluster.CreateSession() if err != nil { t.Fatal(err) } - defer db.Close() - ctx, cancel := context.WithCancel(context.Background()) - cancel() - err = db.Query("timeout").WithContext(ctx).Exec() - if err != context.Canceled { - t.Fatalf("expected to get context cancel error: %v got %v", context.Canceled, err) + if err := db.Query("void").Exec(); err != nil { + t.Fatal(err) + } + + frames := observer.getFrames() + expFrames := []frameOp{opSupported, opReady, opResult} + if len(frames) != len(expFrames) { + t.Fatalf("Expected to receive %d frames, instead received %d", len(expFrames), len(frames)) + } + + for i, op := range expFrames { + if op != frames[i].Opcode { + t.Fatalf("expected frame %d to be %v got %v", i, op, frames[i]) + } + } + voidResultFrame := frames[2] + if voidResultFrame.Length != int32(4) { + t.Fatalf("Expected to receive frame with body length 4, instead received body length %d", voidResultFrame.Length) } } -func NewTestServer(t testing.TB, protocol uint8, ctx context.Context) *TestServer { - laddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0") +func NewTestServerWithAddress(addr string, t testing.TB, protocol uint8, ctx context.Context) *TestServer { + laddr, err := net.ResolveTCPAddr("tcp", addr) if err != nil { t.Fatal(err) } @@ -678,6 +1046,10 @@ return srv } +func NewTestServer(t testing.TB, protocol uint8, ctx context.Context) *TestServer { + return NewTestServerWithAddress("127.0.0.1:0", t, protocol, ctx) +} + func NewSSLTestServer(t testing.TB, protocol uint8, ctx context.Context) *TestServer { pem, err := ioutil.ReadFile("testdata/pki/ca.crt") certPool := x509.NewCertPool() @@ -737,12 +1109,16 @@ closed bool } +func (srv *TestServer) session() (*Session, error) { + return testCluster(protoVersion(srv.protocol), srv.Address).CreateSession() +} + func (srv *TestServer) host() *HostInfo { - host, err := hostInfo(srv.Address, 9042) + hosts, err := hostInfo(srv.Address, 9042) if err != nil { srv.t.Fatal(err) } - return host + return hosts[0] } func (srv *TestServer) closeWatch() { @@ -756,13 +1132,7 @@ func (srv *TestServer) serve() { defer srv.listen.Close() - for { - select { - case <-srv.ctx.Done(): - return - default: - } - + for !srv.isClosed() { conn, err := srv.listen.Accept() if err != nil { break @@ -770,26 +1140,13 @@ go func(conn net.Conn) { defer conn.Close() - for { - select { - case <-srv.ctx.Done(): - return - default: - } - + for !srv.isClosed() { framer, err 
:= srv.readFrame(conn) if err != nil { if err == io.EOF { return } - - select { - case <-srv.ctx.Done(): - return - default: - } - - srv.t.Error(err) + srv.errorLocked(err) return } @@ -824,16 +1181,19 @@ srv.closeLocked() } +func (srv *TestServer) errorLocked(err interface{}) { + srv.mu.Lock() + defer srv.mu.Unlock() + if srv.closed { + return + } + srv.t.Error(err) +} + func (srv *TestServer) process(f *framer) { head := f.header if head == nil { - select { - case <-srv.ctx.Done(): - return - default: - } - - srv.t.Error("process frame with a nil header") + srv.errorLocked("process frame with a nil header") return } @@ -885,6 +1245,19 @@ } }() return + case "speculative": + atomic.AddInt64(&srv.nKillReq, 1) + if atomic.LoadInt64(&srv.nKillReq) > 3 { + f.writeHeader(0, opResult, head.stream) + f.writeInt(resultKindVoid) + f.writeString("speculative query success on the node " + srv.Address) + } else { + f.writeHeader(0, opError, head.stream) + f.writeInt(0x1001) + f.writeString("speculative error") + rand.Seed(time.Now().UnixNano()) + <-time.After(time.Millisecond * 120) + } default: f.writeHeader(0, opResult, head.stream) f.writeInt(resultKindVoid) @@ -901,13 +1274,7 @@ f.wbuf[0] = srv.protocol | 0x80 if err := f.finishWrite(); err != nil { - select { - case <-srv.ctx.Done(): - return - default: - } - - srv.t.Error(err) + srv.errorLocked(err) } } diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/control.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/control.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/control.go 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/control.go 2019-11-02 13:15:23.000000000 +0000 @@ -7,6 +7,7 @@ "fmt" "math/rand" "net" + "os" "regexp" "strconv" "sync" @@ -31,13 +32,15 @@ // Ensure that the atomic variable is aligned to a 64bit boundary // so that atomic operations can be applied on 32bit architectures. 
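The alignment comment above reflects a general Go rule rather than anything gocql-specific: sync/atomic only guarantees 64-bit alignment for the first word of an allocated struct, so on 32-bit platforms any field accessed with 64-bit atomics must be declared first. A generic illustration of the safe layout (assumed names, not gocql code):

import "sync/atomic"

// counters illustrates the layout rule: the int64 touched with atomic
// operations is the first field, so it stays 8-byte aligned even when the
// struct is allocated on a 32-bit architecture.
type counters struct {
	requests int64 // accessed atomically; must stay first for alignment
	label    string
}

func (c *counters) inc() int64 {
	return atomic.AddInt64(&c.requests, 1)
}
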
type controlConn struct { + started int32 + reconnecting int32 + session *Session conn atomic.Value retry RetryPolicy - started int32 - quit chan struct{} + quit chan struct{} } func createControlConn(session *Session) *controlConn { @@ -47,7 +50,7 @@ retry: &SimpleRetryPolicy{NumRetries: 3}, } - control.conn.Store((*Conn)(nil)) + control.conn.Store((*connHost)(nil)) return control } @@ -58,12 +61,16 @@ } sleepTime := 1 * time.Second + timer := time.NewTimer(sleepTime) + defer timer.Stop() for { + timer.Reset(sleepTime) + select { case <-c.quit: return - case <-time.After(sleepTime): + case <-timer.C: } resp, err := c.writeFrame(&writeOptionsFrame{}) @@ -86,14 +93,13 @@ // try to connect a bit faster sleepTime = 1 * time.Second c.reconnect(true) - // time.Sleep(5 * time.Second) continue } } -var hostLookupPreferV4 = false +var hostLookupPreferV4 = os.Getenv("GOCQL_HOST_LOOKUP_PREFER_V4") == "true" -func hostInfo(addr string, defaultPort int) (*HostInfo, error) { +func hostInfo(addr string, defaultPort int) ([]*HostInfo, error) { var port int host, portStr, err := net.SplitHostPort(addr) if err != nil { @@ -106,44 +112,51 @@ } } - ip := net.ParseIP(host) - if ip == nil { - ips, err := net.LookupIP(host) - if err != nil { - return nil, err - } else if len(ips) == 0 { - return nil, fmt.Errorf("No IP's returned from DNS lookup for %q", addr) - } + var hosts []*HostInfo - if hostLookupPreferV4 { - for _, v := range ips { - if v4 := v.To4(); v4 != nil { - ip = v4 - break - } - } - if ip == nil { - ip = ips[0] + // Check if host is a literal IP address + if ip := net.ParseIP(host); ip != nil { + hosts = append(hosts, &HostInfo{hostname: host, connectAddress: ip, port: port}) + return hosts, nil + } + + // Look up host in DNS + ips, err := LookupIP(host) + if err != nil { + return nil, err + } else if len(ips) == 0 { + return nil, fmt.Errorf("No IP's returned from DNS lookup for %q", addr) + } + + // Filter to v4 addresses if any present + if hostLookupPreferV4 { + var preferredIPs []net.IP + for _, v := range ips { + if v4 := v.To4(); v4 != nil { + preferredIPs = append(preferredIPs, v4) } - } else { - // TODO(zariel): should we check that we can connect to any of the ips? - ip = ips[0] } + if len(preferredIPs) != 0 { + ips = preferredIPs + } + } + for _, ip := range ips { + hosts = append(hosts, &HostInfo{hostname: host, connectAddress: ip, port: port}) } - return &HostInfo{connectAddress: ip, port: port}, nil + return hosts, nil } func shuffleHosts(hosts []*HostInfo) []*HostInfo { - mutRandr.Lock() - perm := randr.Perm(len(hosts)) - mutRandr.Unlock() shuffled := make([]*HostInfo, len(hosts)) + copy(shuffled, hosts) - for i, host := range hosts { - shuffled[perm[i]] = host - } + mutRandr.Lock() + randr.Shuffle(len(hosts), func(i, j int) { + shuffled[i], shuffled[j] = shuffled[j], shuffled[i] + }) + mutRandr.Unlock() return shuffled } @@ -153,10 +166,13 @@ // node. 
shuffled := shuffleHosts(endpoints) + cfg := *c.session.connCfg + cfg.disableCoalesce = true + var err error for _, host := range shuffled { var conn *Conn - conn, err = c.session.connect(host, c) + conn, err = c.session.dial(host, &cfg, c) if err == nil { return conn, nil } @@ -197,14 +213,20 @@ handler := connErrorHandlerFn(func(c *Conn, err error, closed bool) { // we should never get here, but if we do it means we connected to a // host successfully which means our attempted protocol version worked + if !closed { + c.Close() + } }) var err error for _, host := range hosts { var conn *Conn - conn, err = Connect(host, &connCfg, handler, c.session) - if err == nil { + conn, err = c.session.dial(host, &connCfg, handler) + if conn != nil { conn.Close() + } + + if err == nil { return connCfg.ProtoVersion, nil } @@ -239,35 +261,31 @@ return nil } +type connHost struct { + conn *Conn + host *HostInfo +} + func (c *controlConn) setupConn(conn *Conn) error { if err := c.registerEvents(conn); err != nil { conn.Close() return err } - c.conn.Store(conn) - - if v, ok := conn.conn.RemoteAddr().(*net.TCPAddr); ok { - c.session.handleNodeUp(copyBytes(v.IP), v.Port, false) - return nil - } - - host, portstr, err := net.SplitHostPort(conn.conn.RemoteAddr().String()) - if err != nil { - return err - } - - port, err := strconv.Atoi(portstr) + // TODO(zariel): do we need to fetch host info everytime + // the control conn connects? Surely we have it cached? + host, err := conn.localHostInfo(context.TODO()) if err != nil { return err } - ip := net.ParseIP(host) - if ip == nil { - return fmt.Errorf("invalid remote addr: addr=%v host=%q", conn.conn.RemoteAddr(), host) + ch := &connHost{ + conn: conn, + host: host, } - c.session.handleNodeUp(ip, port, false) + c.conn.Store(ch) + c.session.handleNodeUp(host.ConnectAddress(), host.Port(), false) return nil } @@ -308,14 +326,18 @@ } func (c *controlConn) reconnect(refreshring bool) { + if !atomic.CompareAndSwapInt32(&c.reconnecting, 0, 1) { + return + } + defer atomic.StoreInt32(&c.reconnecting, 0) // TODO: simplify this function, use session.ring to get hosts instead of the // connection pool var host *HostInfo - oldConn := c.conn.Load().(*Conn) - if oldConn != nil { - host = oldConn.host - oldConn.Close() + ch := c.getConn() + if ch != nil { + host = ch.host + ch.conn.Close() } var newConn *Conn @@ -325,7 +347,9 @@ if err != nil { // host is dead // TODO: this is replicated in a few places - c.session.handleNodeDown(host.ConnectAddress(), host.Port()) + if c.session.cfg.ConvictionPolicy.AddFailure(err, host) { + c.session.handleNodeDown(host.ConnectAddress(), host.Port()) + } } else { newConn = conn } @@ -364,21 +388,28 @@ return } - oldConn := c.conn.Load().(*Conn) - if oldConn != conn { + oldConn := c.getConn() + + // If connection has long gone, and not been attempted for awhile, + // it's possible to have oldConn as nil here (#1297). 
+ if oldConn != nil && oldConn.conn != conn { return } - c.reconnect(true) + c.reconnect(false) +} + +func (c *controlConn) getConn() *connHost { + return c.conn.Load().(*connHost) } func (c *controlConn) writeFrame(w frameWriter) (frame, error) { - conn := c.conn.Load().(*Conn) - if conn == nil { + ch := c.getConn() + if ch == nil { return nil, errNoControl } - framer, err := conn.exec(context.Background(), w, nil) + framer, err := ch.conn.exec(context.Background(), w, nil) if err != nil { return nil, err } @@ -386,13 +417,13 @@ return framer.parseFrame() } -func (c *controlConn) withConn(fn func(*Conn) *Iter) *Iter { +func (c *controlConn) withConnHost(fn func(*connHost) *Iter) *Iter { const maxConnectAttempts = 5 connectAttempts := 0 for i := 0; i < maxConnectAttempts; i++ { - conn := c.conn.Load().(*Conn) - if conn == nil { + ch := c.getConn() + if ch == nil { if connectAttempts > maxConnectAttempts { break } @@ -403,26 +434,32 @@ continue } - return fn(conn) + return fn(ch) } return &Iter{err: errNoControl} } +func (c *controlConn) withConn(fn func(*Conn) *Iter) *Iter { + return c.withConnHost(func(ch *connHost) *Iter { + return fn(ch.conn) + }) +} + // query will return nil if the connection is closed or nil func (c *controlConn) query(statement string, values ...interface{}) (iter *Iter) { q := c.session.Query(statement, values...).Consistency(One).RoutingKey([]byte{}).Trace(nil) for { iter = c.withConn(func(conn *Conn) *Iter { - return conn.executeQuery(q) + return conn.executeQuery(context.TODO(), q) }) if gocqlDebug && iter.err != nil { Logger.Printf("control: error executing %q: %v\n", statement, iter.err) } - q.attempts++ + q.AddAttempts(1, c.getConn().host) if iter.err == nil || !c.retry.Attempt(q) { break } @@ -433,25 +470,18 @@ func (c *controlConn) awaitSchemaAgreement() error { return c.withConn(func(conn *Conn) *Iter { - return &Iter{err: conn.awaitSchemaAgreement()} + return &Iter{err: conn.awaitSchemaAgreement(context.TODO())} }).err } -func (c *controlConn) GetHostInfo() *HostInfo { - conn := c.conn.Load().(*Conn) - if conn == nil { - return nil - } - return conn.host -} - func (c *controlConn) close() { if atomic.CompareAndSwapInt32(&c.started, 1, -1) { c.quit <- struct{}{} } - conn := c.conn.Load().(*Conn) - if conn != nil { - conn.Close() + + ch := c.getConn() + if ch != nil { + ch.conn.Close() } } diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/control_test.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/control_test.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/control_test.go 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/control_test.go 2019-11-02 13:15:23.000000000 +0000 @@ -18,12 +18,13 @@ } for i, test := range tests { - host, err := hostInfo(test.addr, 1) + hosts, err := hostInfo(test.addr, 1) if err != nil { t.Errorf("%d: %v", i, err) continue } + host := hosts[0] if !host.ConnectAddress().Equal(test.ip) { t.Errorf("expected ip %v got %v for addr %q", test.ip, host.ConnectAddress(), test.addr) } diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/cqltypes.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/cqltypes.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/cqltypes.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/cqltypes.go 2019-11-02 13:15:23.000000000 +0000 @@ -0,0 +1,11 @@ +// Copyright (c) 2012 The gocql Authors. All rights reserved. 
+// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gocql + +type Duration struct { + Months int32 + Days int32 + Nanoseconds int64 +} diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/debian/changelog golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/debian/changelog --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/debian/changelog 2018-07-04 11:00:56.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/debian/changelog 2019-11-02 14:15:16.000000000 +0000 @@ -1,3 +1,13 @@ +golang-github-gocql-gocql (0.0~git20191102.0.9faa4c0-1) unstable; urgency=medium + + * New upstream snapshot. + This introduces compatibility with Go 1.13 (e.g. tests). + * Drop patch applied by upstream. + * Bump Standards-Version. + * Use debhelper 12. + + -- Sascha Steinbiss Sat, 02 Nov 2019 15:15:16 +0100 + golang-github-gocql-gocql (0.0~git20171009.0.2416cf3-3) unstable; urgency=medium [ Alexandre Viau ] diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/debian/compat golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/debian/compat --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/debian/compat 2018-07-04 11:00:56.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/debian/compat 2019-11-02 14:15:16.000000000 +0000 @@ -1 +1 @@ -11 +12 diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/debian/control golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/debian/control --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/debian/control 2018-07-04 11:00:56.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/debian/control 2019-11-02 14:15:16.000000000 +0000 @@ -3,13 +3,13 @@ Priority: optional Maintainer: Debian Go Packaging Team Uploaders: Sascha Steinbiss -Build-Depends: debhelper (>= 11), +Build-Depends: debhelper (>= 12), dh-golang, golang-any, golang-github-golang-snappy-dev, golang-github-hailocab-go-hostpool-dev, golang-gopkg-inf.v0-dev -Standards-Version: 4.1.4 +Standards-Version: 4.4.1 Homepage: https://github.com/gocql/gocql Vcs-Browser: https://salsa.debian.org/go-team/packages/golang-github-gocql-gocql Vcs-Git: https://salsa.debian.org/go-team/packages/golang-github-gocql-gocql.git diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/debian/patches/fix-ftbfs-on-32bit.patch golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/debian/patches/fix-ftbfs-on-32bit.patch --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/debian/patches/fix-ftbfs-on-32bit.patch 2018-07-04 11:00:52.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/debian/patches/fix-ftbfs-on-32bit.patch 1970-01-01 00:00:00.000000000 +0000 @@ -1,21 +0,0 @@ -Description: fix FTBFS on 32-bit platforms - During Debian autobuilds on 32-bit platforms (armhf, i386) we noticed that - one of the expected result hash values, when given as a literal number, - overflows int. This causes the build of the testing code to fail on such - platforms. Explicitly requiring the value to be 64-bit fixes this. 
-Author: Sascha Steinbiss -Forwarded: https://github.com/gocql/gocql/pull/1008 -Last-Update: 2017-10-27 ---- a/internal/murmur/murmur_test.go -+++ b/internal/murmur/murmur_test.go -@@ -66,8 +66,8 @@ - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - h1 := Murmur3H1(data) -- if h1 != 7627370222079200297 { -- b.Fatalf("expected %d got %d", 7627370222079200297, h1) -+ if h1 != uint64(7627370222079200297) { -+ b.Fatalf("expected %d got %d", uint64(7627370222079200297), h1) - } - } - }) diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/debian/patches/series golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/debian/patches/series --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/debian/patches/series 2018-07-04 11:00:52.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/debian/patches/series 1970-01-01 00:00:00.000000000 +0000 @@ -1 +0,0 @@ -fix-ftbfs-on-32bit.patch diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/doc.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/doc.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/doc.go 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/doc.go 2019-11-02 13:15:23.000000000 +0000 @@ -4,6 +4,6 @@ // Package gocql implements a fast and robust Cassandra driver for the // Go programming language. -package gocql +package gocql // import "github.com/gocql/gocql" // TODO(tux21b): write more docs. diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/errors.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/errors.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/errors.go 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/errors.go 2019-11-02 13:15:23.000000000 +0000 @@ -15,6 +15,7 @@ errReadFailure = 0x1300 errFunctionFailure = 0x1400 errWriteFailure = 0x1500 + errCDCWriteFailure = 0x1600 errSyntax = 0x2000 errUnauthorized = 0x2100 errInvalid = 0x2200 @@ -63,6 +64,8 @@ return fmt.Sprintf("[request_error_unavailable consistency=%s required=%d alive=%d]", e.Consistency, e.Required, e.Alive) } +type ErrorMap map[string]uint16 + type RequestErrWriteTimeout struct { errorFrame Consistency Consistency @@ -78,6 +81,11 @@ BlockFor int NumFailures int WriteType string + ErrorMap ErrorMap +} + +type RequestErrCDCWriteFailure struct { + errorFrame } type RequestErrReadTimeout struct { @@ -106,6 +114,7 @@ BlockFor int NumFailures int DataPresent bool + ErrorMap ErrorMap } type RequestErrFunctionFailure struct { diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/errors_test.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/errors_test.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/errors_test.go 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/errors_test.go 2019-11-02 13:15:23.000000000 +0000 @@ -1,4 +1,4 @@ -// +build all integration +// +build all cassandra package gocql diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/events.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/events.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/events.go 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/events.go 2019-11-02 13:15:23.000000000 +0000 @@ -80,9 +80,6 @@ } func (s *Session) handleEvent(framer *framer) { - // TODO(zariel): need to debounce events frames, and possible also events - defer framerPool.Put(framer) - frame, err := framer.parseFrame() if err != nil { // 
TODO: logger @@ -94,9 +91,10 @@ Logger.Printf("gocql: handling frame: %v\n", frame) } - // TODO: handle medatadata events switch f := frame.(type) { - case *schemaChangeKeyspace, *schemaChangeFunction, *schemaChangeTable: + case *schemaChangeKeyspace, *schemaChangeFunction, + *schemaChangeTable, *schemaChangeAggregate, *schemaChangeType: + s.schemaEvents.debounce(frame) case *topologyChangeEventFrame, *statusChangeEventFrame: s.nodeEvents.debounce(frame) @@ -106,22 +104,29 @@ } func (s *Session) handleSchemaEvent(frames []frame) { - s.mu.RLock() - defer s.mu.RUnlock() - - if s.schemaDescriber == nil { - return - } + // TODO: debounce events for _, frame := range frames { switch f := frame.(type) { case *schemaChangeKeyspace: s.schemaDescriber.clearSchema(f.keyspace) + s.handleKeyspaceChange(f.keyspace, f.change) case *schemaChangeTable: s.schemaDescriber.clearSchema(f.keyspace) + case *schemaChangeAggregate: + s.schemaDescriber.clearSchema(f.keyspace) + case *schemaChangeFunction: + s.schemaDescriber.clearSchema(f.keyspace) + case *schemaChangeType: + s.schemaDescriber.clearSchema(f.keyspace) } } } +func (s *Session) handleKeyspaceChange(keyspace, change string) { + s.control.awaitSchemaAgreement() + s.policy.KeyspaceChanged(KeyspaceUpdateEvent{Keyspace: keyspace, Change: change}) +} + func (s *Session) handleNodeEvent(frames []frame) { type nodeEvent struct { change string @@ -173,16 +178,30 @@ } } +func (s *Session) addNewNode(host *HostInfo) { + if s.cfg.filterHost(host) { + return + } + + host.setState(NodeUp) + s.pool.addHost(host) + s.policy.AddHost(host) +} + func (s *Session) handleNewNode(ip net.IP, port int, waitForBinary bool) { + if gocqlDebug { + Logger.Printf("gocql: Session.handleNewNode: %s:%d\n", ip.String(), port) + } + + ip, port = s.cfg.translateAddressPort(ip, port) + // Get host info and apply any filters to the host - hostInfo, err := s.hostSource.GetHostInfo(ip, port) + hostInfo, err := s.hostSource.getHostInfo(ip, port) if err != nil { Logger.Printf("gocql: events: unable to fetch host info for (%s:%d): %v\n", ip, port, err) return - } - - // If hostInfo is nil, this host was filtered out by cfg.HostFilter - if hostInfo == nil { + } else if hostInfo == nil { + // If hostInfo is nil, this host was filtered out by cfg.HostFilter return } @@ -191,20 +210,23 @@ } // should this handle token moving? 
- if existing, ok := s.ring.addHostIfMissing(hostInfo); ok { - existing.update(hostInfo) - hostInfo = existing - } + hostInfo = s.ring.addOrUpdate(hostInfo) + + s.addNewNode(hostInfo) - s.pool.addHost(hostInfo) - s.policy.AddHost(hostInfo) - hostInfo.setState(NodeUp) if s.control != nil && !s.cfg.IgnorePeerAddr { + // TODO(zariel): debounce ring refresh s.hostSource.refreshRing() } } func (s *Session) handleRemovedNode(ip net.IP, port int) { + if gocqlDebug { + Logger.Printf("gocql: Session.handleRemovedNode: %s:%d\n", ip.String(), port) + } + + ip, port = s.cfg.translateAddressPort(ip, port) + // we remove all nodes but only add ones which pass the filter host := s.ring.getHost(ip) if host == nil { @@ -225,34 +247,30 @@ } } -func (s *Session) handleNodeUp(ip net.IP, port int, waitForBinary bool) { +func (s *Session) handleNodeUp(eventIp net.IP, eventPort int, waitForBinary bool) { if gocqlDebug { - Logger.Printf("gocql: Session.handleNodeUp: %s:%d\n", ip.String(), port) + Logger.Printf("gocql: Session.handleNodeUp: %s:%d\n", eventIp.String(), eventPort) } - host := s.ring.getHost(ip) - if host != nil { - // If we receive a node up event and user has asked us to ignore the peer address use - // the address provide by the event instead the address provide by peer the table. - if s.cfg.IgnorePeerAddr && !host.ConnectAddress().Equal(ip) { - host.SetConnectAddress(ip) - } - - if s.cfg.HostFilter != nil && !s.cfg.HostFilter.Accept(host) { - return - } + ip, _ := s.cfg.translateAddressPort(eventIp, eventPort) - if t := host.Version().nodeUpDelay(); t > 0 && waitForBinary { - time.Sleep(t) - } + host := s.ring.getHost(ip) + if host == nil { + // TODO(zariel): avoid the need to translate twice in this + // case + s.handleNewNode(eventIp, eventPort, waitForBinary) + return + } - s.pool.hostUp(host) - s.policy.HostUp(host) - host.setState(NodeUp) + if s.cfg.HostFilter != nil && !s.cfg.HostFilter.Accept(host) { return } - s.handleNewNode(ip, port, waitForBinary) + if t := host.Version().nodeUpDelay(); t > 0 && waitForBinary { + time.Sleep(t) + } + + s.addNewNode(host) } func (s *Session) handleNodeDown(ip net.IP, port int) { diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/frame.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/frame.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/frame.go 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/frame.go 2019-11-02 13:15:23.000000000 +0000 @@ -5,6 +5,7 @@ package gocql import ( + "context" "errors" "fmt" "io" @@ -12,12 +13,18 @@ "net" "runtime" "strings" - "sync" "time" ) type unsetColumn struct{} +// UnsetValue represents a value used in a query binding that will be ignored by Cassandra. +// +// By setting a field to the unset value Cassandra will ignore the write completely. +// The main advantage is the ability to keep the same prepared statement even when you don't +// want to update some fields, where before you needed to make another prepared statement. +// +// UnsetValue is only available when using the version 4 of the protocol. 
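For illustration, a caller-side sketch of reusing one prepared INSERT while leaving a column untouched; the open session and the users table are assumptions for this example, not part of the upstream change:

// Requires protocol version 4 or later; `session` and the `users` table
// are assumed. Binding UnsetValue skips the write to `email` for this row
// while keeping the same prepared statement.
if err := session.Query(
	`INSERT INTO users (id, name, email) VALUES (?, ?, ?)`,
	42, "ada", gocql.UnsetValue,
).Exec(); err != nil {
	// handle error
}
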
var UnsetValue = unsetColumn{} type namedValue struct { @@ -40,6 +47,7 @@ protoVersion2 = 0x02 protoVersion3 = 0x03 protoVersion4 = 0x04 + protoVersion5 = 0x05 maxFrameSize = 256 * 1024 * 1024 ) @@ -149,12 +157,17 @@ flagWithSerialConsistency byte = 0x10 flagDefaultTimestamp byte = 0x20 flagWithNameValues byte = 0x40 + flagWithKeyspace byte = 0x80 + + // prepare flags + flagWithPreparedKeyspace uint32 = 0x01 // header flags flagCompress byte = 0x01 flagTracing byte = 0x02 flagCustomPayload byte = 0x04 flagWarning byte = 0x08 + flagBetaProtocol byte = 0x10 ) type Consistency uint16 @@ -227,20 +240,30 @@ return nil } -func ParseConsistency(s string) (consistency Consistency, err error) { +func ParseConsistency(s string) Consistency { + var c Consistency + if err := c.UnmarshalText([]byte(strings.ToUpper(s))); err != nil { + panic(err) + } + return c +} + +// ParseConsistencyWrapper wraps gocql.ParseConsistency to provide an err +// return instead of a panic +func ParseConsistencyWrapper(s string) (consistency Consistency, err error) { err = consistency.UnmarshalText([]byte(strings.ToUpper(s))) return } -// ParseConsistencyWrapper is deprecated use ParseConsistency instead. -var ParseConsistencyWrapper = ParseConsistency - -func MustParseConsistency(s string) Consistency { - c, err := ParseConsistency(s) +// MustParseConsistency is the same as ParseConsistency except it returns +// an error (never). It is kept here since breaking changes are not good. +// DEPRECATED: use ParseConsistency if you want a panic on parse error. +func MustParseConsistency(s string) (Consistency, error) { + c, err := ParseConsistencyWrapper(s) if err != nil { panic(err) } - return c + return c, nil } type SerialConsistency uint16 @@ -309,13 +332,12 @@ } type frameHeader struct { - version protoVersion - flags byte - stream int - op frameOp - length int - customPayload map[string][]byte - warnings []string + version protoVersion + flags byte + stream int + op frameOp + length int + warnings []string } func (f frameHeader) String() string { @@ -328,13 +350,32 @@ const defaultBufSize = 128 -var framerPool = sync.Pool{ - New: func() interface{} { - return &framer{ - wbuf: make([]byte, defaultBufSize), - readBuffer: make([]byte, defaultBufSize), - } - }, +type ObservedFrameHeader struct { + Version protoVersion + Flags byte + Stream int16 + Opcode frameOp + Length int32 + + // StartHeader is the time we started reading the frame header off the network connection. + Start time.Time + // EndHeader is the time we finished reading the frame header off the network connection. + End time.Time + + // Host is Host of the connection the frame header was read from. + Host *HostInfo +} + +func (f ObservedFrameHeader) String() string { + return fmt.Sprintf("[observed header version=%s flags=0x%x stream=%d op=%s length=%d]", f.Version, f.Flags, f.Stream, f.Opcode, f.Length) +} + +// FrameHeaderObserver is the interface implemented by frame observers / stat collectors. +// +// Experimental, this interface and use may change +type FrameHeaderObserver interface { + // ObserveFrameHeader gets called on every received frame header. 
+ ObserveFrameHeader(context.Context, ObservedFrameHeader) } // a framer is responsible for reading, writing and parsing frames on a single stream @@ -359,14 +400,22 @@ rbuf []byte wbuf []byte + + customPayload map[string][]byte } func newFramer(r io.Reader, w io.Writer, compressor Compressor, version byte) *framer { - f := framerPool.Get().(*framer) + f := &framer{ + wbuf: make([]byte, defaultBufSize), + readBuffer: make([]byte, defaultBufSize), + } var flags byte if compressor != nil { flags |= flagCompress } + if version == protoVersion5 { + flags |= flagBetaProtocol + } version &= protoVersionMask @@ -404,7 +453,7 @@ version := p[0] & protoVersionMask - if version < protoVersion1 || version > protoVersion4 { + if version < protoVersion1 || version > protoVersion5 { return frameHeader{}, fmt.Errorf("gocql: unsupported protocol response version: %d", version) } @@ -449,6 +498,11 @@ f.flags |= flagTracing } +// explicitly enables the custom payload flag +func (f *framer) payload() { + f.flags |= flagCustomPayload +} + // reads a frame form the wire into the framers buffer func (f *framer) readFrame(head *frameHeader) error { if head.length < 0 { @@ -513,7 +567,7 @@ } if f.header.flags&flagCustomPayload == flagCustomPayload { - f.header.customPayload = f.readBytesMap() + f.customPayload = f.readBytesMap() } // assumes that the frame body has been read into rbuf @@ -607,7 +661,14 @@ res.Consistency = f.readConsistency() res.Received = f.readInt() res.BlockFor = f.readInt() + if f.proto > protoVersion4 { + res.ErrorMap = f.readErrorMap() + res.NumFailures = len(res.ErrorMap) + } else { + res.NumFailures = f.readInt() + } res.DataPresent = f.readByte() != 0 + return res case errWriteFailure: res := &RequestErrWriteFailure{ @@ -616,17 +677,29 @@ res.Consistency = f.readConsistency() res.Received = f.readInt() res.BlockFor = f.readInt() - res.NumFailures = f.readInt() + if f.proto > protoVersion4 { + res.ErrorMap = f.readErrorMap() + res.NumFailures = len(res.ErrorMap) + } else { + res.NumFailures = f.readInt() + } res.WriteType = f.readString() return res case errFunctionFailure: - res := RequestErrFunctionFailure{ + res := &RequestErrFunctionFailure{ errorFrame: errD, } res.Keyspace = f.readString() res.Function = f.readString() res.ArgTypes = f.readStringList() return res + + case errCDCWriteFailure: + res := &RequestErrCDCWriteFailure{ + errorFrame: errD, + } + return res + case errInvalid, errBootstrapping, errConfig, errCredentials, errOverloaded, errProtocol, errServer, errSyntax, errTruncate, errUnauthorized: // TODO(zariel): we should have some distinct types for these errors @@ -636,6 +709,16 @@ } } +func (f *framer) readErrorMap() (errMap ErrorMap) { + errMap = make(ErrorMap) + numErrs := f.readInt() + for i := 0; i < numErrs; i++ { + ip := f.readInetAdressOnly().String() + errMap[ip] = f.readShort() + } + return +} + func (f *framer) writeHeader(flags byte, op frameOp, stream int) { f.wbuf = f.wbuf[:0] f.wbuf = append(f.wbuf, @@ -745,28 +828,42 @@ return fmt.Sprintf("[startup opts=%+v]", w.opts) } -func (w *writeStartupFrame) writeFrame(framer *framer, streamID int) error { - return framer.writeStartupFrame(streamID, w.opts) -} - -func (f *framer) writeStartupFrame(streamID int, options map[string]string) error { +func (w *writeStartupFrame) writeFrame(f *framer, streamID int) error { f.writeHeader(f.flags&^flagCompress, opStartup, streamID) - f.writeStringMap(options) + f.writeStringMap(w.opts) return f.finishWrite() } type writePrepareFrame struct { - statement string + statement 
string + keyspace string + customPayload map[string][]byte } -func (w *writePrepareFrame) writeFrame(framer *framer, streamID int) error { - return framer.writePrepareFrame(streamID, w.statement) -} +func (w *writePrepareFrame) writeFrame(f *framer, streamID int) error { + if len(w.customPayload) > 0 { + f.payload() + } + f.writeHeader(f.flags, opPrepare, streamID) + f.writeCustomPayload(&w.customPayload) + f.writeLongString(w.statement) + + var flags uint32 = 0 + if w.keyspace != "" { + if f.proto > protoVersion4 { + flags |= flagWithPreparedKeyspace + } else { + panic(fmt.Errorf("The keyspace can only be set with protocol 5 or higher")) + } + } + if f.proto > protoVersion4 { + f.writeUint(flags) + } + if w.keyspace != "" { + f.writeString(w.keyspace) + } -func (f *framer) writePrepareFrame(stream int, statement string) error { - f.writeHeader(f.flags, opPrepare, stream) - f.writeLongString(statement) return f.finishWrite() } @@ -866,7 +963,7 @@ } if meta.flags&flagHasMorePages == flagHasMorePages { - meta.pagingState = f.readBytes() + meta.pagingState = copyBytes(f.readBytes()) } if meta.flags&flagNoMetaData == flagNoMetaData { @@ -917,6 +1014,10 @@ actualColCount int } +func (r *resultMetadata) morePages() bool { + return r.flags&flagHasMorePages == flagHasMorePages +} + func (r resultMetadata) String() string { return fmt.Sprintf("[metadata flags=0x%x paging_state=% X columns=%v]", r.flags, r.pagingState, r.columns) } @@ -951,7 +1052,7 @@ meta.actualColCount = meta.colCount if meta.flags&flagHasMorePages == flagHasMorePages { - meta.pagingState = f.readBytes() + meta.pagingState = copyBytes(f.readBytes()) } if meta.flags&flagNoMetaData == flagNoMetaData { @@ -1102,6 +1203,14 @@ return fmt.Sprintf("[event schema_change change=%q keyspace=%q object=%q]", f.change, f.keyspace, f.object) } +type schemaChangeType struct { + frameHeader + + change string + keyspace string + object string +} + type schemaChangeFunction struct { frameHeader @@ -1111,6 +1220,15 @@ args []string } +type schemaChangeAggregate struct { + frameHeader + + change string + keyspace string + name string + args []string +} + func (f *framer) parseResultSchemaChange() frame { if f.proto <= protoVersion2 { change := f.readString() @@ -1146,7 +1264,7 @@ frame.keyspace = f.readString() return frame - case "TABLE", "TYPE": + case "TABLE": frame := &schemaChangeTable{ frameHeader: *f.header, change: change, @@ -1156,7 +1274,17 @@ frame.object = f.readString() return frame - case "FUNCTION", "AGGREGATE": + case "TYPE": + frame := &schemaChangeType{ + frameHeader: *f.header, + change: change, + } + + frame.keyspace = f.readString() + frame.object = f.readString() + + return frame + case "FUNCTION": frame := &schemaChangeFunction{ frameHeader: *f.header, change: change, @@ -1167,6 +1295,17 @@ frame.args = f.readStringList() return frame + case "AGGREGATE": + frame := &schemaChangeAggregate{ + frameHeader: *f.header, + change: change, + } + + frame.keyspace = f.readString() + frame.name = f.readString() + frame.args = f.readStringList() + + return frame default: panic(fmt.Errorf("gocql: unknown SCHEMA_CHANGE target: %q change: %q", target, change)) } @@ -1312,11 +1451,13 @@ // v3+ defaultTimestamp bool defaultTimestampValue int64 + // v5+ + keyspace string } func (q queryParams) String() string { - return fmt.Sprintf("[query_params consistency=%v skip_meta=%v page_size=%d paging_state=%q serial_consistency=%v default_timestamp=%v values=%v]", - q.consistency, q.skipMeta, q.pageSize, q.pagingState, q.serialConsistency, 
q.defaultTimestamp, q.values) + return fmt.Sprintf("[query_params consistency=%v skip_meta=%v page_size=%d paging_state=%q serial_consistency=%v default_timestamp=%v values=%v keyspace=%s]", + q.consistency, q.skipMeta, q.pageSize, q.pagingState, q.serialConsistency, q.defaultTimestamp, q.values, q.keyspace) } func (f *framer) writeQueryParams(opts *queryParams) { @@ -1357,7 +1498,19 @@ } } - f.writeByte(flags) + if opts.keyspace != "" { + if f.proto > protoVersion4 { + flags |= flagWithKeyspace + } else { + panic(fmt.Errorf("The keyspace can only be set with protocol 5 or higher")) + } + } + + if f.proto > protoVersion4 { + f.writeUint(uint32(flags)) + } else { + f.writeByte(flags) + } if n := len(opts.values); n > 0 { f.writeShort(uint16(n)) @@ -1396,11 +1549,18 @@ } f.writeLong(ts) } + + if opts.keyspace != "" { + f.writeString(opts.keyspace) + } } type writeQueryFrame struct { statement string params queryParams + + // v4+ + customPayload map[string][]byte } func (w *writeQueryFrame) String() string { @@ -1408,11 +1568,15 @@ } func (w *writeQueryFrame) writeFrame(framer *framer, streamID int) error { - return framer.writeQueryFrame(streamID, w.statement, &w.params) + return framer.writeQueryFrame(streamID, w.statement, &w.params, w.customPayload) } -func (f *framer) writeQueryFrame(streamID int, statement string, params *queryParams) error { +func (f *framer) writeQueryFrame(streamID int, statement string, params *queryParams, customPayload map[string][]byte) error { + if len(customPayload) > 0 { + f.payload() + } f.writeHeader(f.flags, opQuery, streamID) + f.writeCustomPayload(&customPayload) f.writeLongString(statement) f.writeQueryParams(params) @@ -1432,6 +1596,9 @@ type writeExecuteFrame struct { preparedID []byte params queryParams + + // v4+ + customPayload map[string][]byte } func (e *writeExecuteFrame) String() string { @@ -1439,11 +1606,15 @@ } func (e *writeExecuteFrame) writeFrame(fr *framer, streamID int) error { - return fr.writeExecuteFrame(streamID, e.preparedID, &e.params) + return fr.writeExecuteFrame(streamID, e.preparedID, &e.params, &e.customPayload) } -func (f *framer) writeExecuteFrame(streamID int, preparedID []byte, params *queryParams) error { +func (f *framer) writeExecuteFrame(streamID int, preparedID []byte, params *queryParams, customPayload *map[string][]byte) error { + if len(*customPayload) > 0 { + f.payload() + } f.writeHeader(f.flags, opExecute, streamID) + f.writeCustomPayload(customPayload) f.writeShortBytes(preparedID) if f.proto > protoVersion1 { f.writeQueryParams(params) @@ -1480,14 +1651,21 @@ serialConsistency SerialConsistency defaultTimestamp bool defaultTimestampValue int64 + + //v4+ + customPayload map[string][]byte } func (w *writeBatchFrame) writeFrame(framer *framer, streamID int) error { - return framer.writeBatchFrame(streamID, w) + return framer.writeBatchFrame(streamID, w, w.customPayload) } -func (f *framer) writeBatchFrame(streamID int, w *writeBatchFrame) error { +func (f *framer) writeBatchFrame(streamID int, w *writeBatchFrame, customPayload map[string][]byte) error { + if len(customPayload) > 0 { + f.payload() + } f.writeHeader(f.flags, opBatch, streamID) + f.writeCustomPayload(&customPayload) f.writeByte(byte(w.typ)) n := len(w.statements) @@ -1507,10 +1685,13 @@ f.writeShort(uint16(len(b.values))) for j := range b.values { - col := &b.values[j] + col := b.values[j] if f.proto > protoVersion2 && col.name != "" { // TODO: move this check into the caller and set a flag on writeBatchFrame // to indicate using named values + if 
f.proto <= protoVersion5 { + return fmt.Errorf("gocql: named query values are not supported in batches, please see https://issues.apache.org/jira/browse/CASSANDRA-10246") + } flags |= flagWithNameValues f.writeString(col.name) } @@ -1532,7 +1713,11 @@ flags |= flagDefaultTimestamp } - f.writeByte(flags) + if f.proto > protoVersion4 { + f.writeUint(uint32(flags)) + } else { + f.writeByte(flags) + } if w.serialConsistency > 0 { f.writeConsistency(Consistency(w.serialConsistency)) @@ -1559,7 +1744,7 @@ } func (f *framer) writeOptionsFrame(stream int, _ *writeOptionsFrame) error { - f.writeHeader(f.flags, opOptions, stream) + f.writeHeader(f.flags&^flagCompress, opOptions, stream) return f.finishWrite() } @@ -1700,7 +1885,7 @@ return l } -func (f *framer) readInet() (net.IP, int) { +func (f *framer) readInetAdressOnly() net.IP { if len(f.rbuf) < 1 { panic(fmt.Errorf("not enough bytes in buffer to read inet size require %d got: %d", 1, len(f.rbuf))) } @@ -1719,9 +1904,11 @@ ip := make([]byte, size) copy(ip, f.rbuf[:size]) f.rbuf = f.rbuf[size:] + return net.IP(ip) +} - port := f.readInt() - return net.IP(ip), port +func (f *framer) readInet() (net.IP, int) { + return f.readInetAdressOnly(), f.readInt() } func (f *framer) readConsistency() Consistency { @@ -1730,7 +1917,7 @@ func (f *framer) readStringMap() map[string]string { size := f.readShort() - m := make(map[string]string) + m := make(map[string]string, size) for i := 0; i < int(size); i++ { k := f.readString() @@ -1743,7 +1930,7 @@ func (f *framer) readBytesMap() map[string][]byte { size := f.readShort() - m := make(map[string][]byte) + m := make(map[string][]byte, size) for i := 0; i < int(size); i++ { k := f.readString() @@ -1756,7 +1943,7 @@ func (f *framer) readStringMultiMap() map[string][]string { size := f.readShort() - m := make(map[string][]string) + m := make(map[string][]string, size) for i := 0; i < int(size); i++ { k := f.readString() @@ -1794,6 +1981,13 @@ byte(n)) } +func appendUint(p []byte, n uint32) []byte { + return append(p, byte(n>>24), + byte(n>>16), + byte(n>>8), + byte(n)) +} + func appendLong(p []byte, n int64) []byte { return append(p, byte(n>>56), @@ -1807,11 +2001,24 @@ ) } +func (f *framer) writeCustomPayload(customPayload *map[string][]byte) { + if len(*customPayload) > 0 { + if f.proto < protoVersion4 { + panic("Custom payload is not supported with version V3 or less") + } + f.writeBytesMap(*customPayload) + } +} + // these are protocol level binary types func (f *framer) writeInt(n int32) { f.wbuf = appendInt(f.wbuf, n) } +func (f *framer) writeUint(n uint32) { + f.wbuf = appendUint(f.wbuf, n) +} + func (f *framer) writeShort(n uint16) { f.wbuf = appendShort(f.wbuf, n) } @@ -1889,3 +2096,11 @@ f.writeString(v) } } + +func (f *framer) writeBytesMap(m map[string][]byte) { + f.writeShort(uint16(len(m))) + for k, v := range m { + f.writeString(k) + f.writeBytes(v) + } +} diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/frame_test.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/frame_test.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/frame_test.go 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/frame_test.go 2019-11-02 13:15:23.000000000 +0000 @@ -2,6 +2,7 @@ import ( "bytes" + "os" "testing" ) @@ -59,6 +60,10 @@ } func TestFrameWriteTooLong(t *testing.T) { + if os.Getenv("TRAVIS") == "true" { + t.Skip("skipping test in travis due to memory pressure with the race detecor") + } + w := &bytes.Buffer{} framer := newFramer(nil, w, nil, 2) 
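The custom payload plumbing above is reachable from the query API. A sketch with made-up payload keys and table names; it assumes protocol 4 or newer and a server-side query handler that actually attaches a payload to its response:

    payload := map[string][]byte{"request-id": []byte("abc123")}

    iter := session.Query(`SELECT id FROM gocql_test.items WHERE id = ?`, 1).
            CustomPayload(payload).
            Iter()

    echoed := iter.GetCustomPayload() // whatever payload the server returned, if any
    if err := iter.Close(); err != nil {
            // handle the query error
    }
    _ = echoed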
@@ -71,6 +76,10 @@ } func TestFrameReadTooLong(t *testing.T) { + if os.Getenv("TRAVIS") == "true" { + t.Skip("skipping test in travis due to memory pressure with the race detecor") + } + r := &bytes.Buffer{} r.Write(make([]byte, maxFrameSize+1)) // write a new header right after this frame to verify that we can read it diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/.github/issue_template.md golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/.github/issue_template.md --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/.github/issue_template.md 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/.github/issue_template.md 2019-11-02 13:15:23.000000000 +0000 @@ -16,7 +16,7 @@ --- -If you are having connectivy related issues please share the following additional information +If you are having connectivity related issues please share the following additional information ### Describe your Cassandra cluster please provide the following information diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/.gitignore golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/.gitignore --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/.gitignore 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/.gitignore 2019-11-02 13:15:23.000000000 +0000 @@ -0,0 +1,5 @@ +gocql-fuzz +fuzz-corpus +fuzz-work +gocql.test +.idea diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/go.mod golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/go.mod --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/go.mod 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/go.mod 2019-11-02 13:15:23.000000000 +0000 @@ -0,0 +1,13 @@ +module github.com/gocql/gocql + +require ( + github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 // indirect + github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 // indirect + github.com/golang/snappy v0.0.0-20170215233205-553a64147049 + github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed + github.com/kr/pretty v0.1.0 // indirect + github.com/stretchr/testify v1.3.0 // indirect + gopkg.in/inf.v0 v0.9.1 +) + +go 1.13 diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/go.sum golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/go.sum --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/go.sum 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/go.sum 2019-11-02 13:15:23.000000000 +0000 @@ -0,0 +1,22 @@ +github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 h1:mXoPYz/Ul5HYEDvkta6I8/rnYM5gSdSV2tJ6XbZuEtY= +github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k= +github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY= +github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang/snappy v0.0.0-20170215233205-553a64147049 h1:K9KHZbXKpGydfDN0aZrsoHpLJlZsBrGMFWbgLDGnPZk= +github.com/golang/snappy v0.0.0-20170215233205-553a64147049/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8= 
+github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed/go.mod h1:tMWxXQ9wFIaZeTI9F+hmhFiGpFmhOHzyShyFUhRm0H4= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= +gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/helpers.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/helpers.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/helpers.go 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/helpers.go 2019-11-02 13:15:23.000000000 +0000 @@ -7,6 +7,7 @@ import ( "fmt" "math/big" + "net" "reflect" "strings" "time" @@ -25,6 +26,8 @@ return reflect.TypeOf(*new(string)) case TypeBigInt, TypeCounter: return reflect.TypeOf(*new(int64)) + case TypeTime: + return reflect.TypeOf(*new(time.Duration)) case TypeTimestamp: return reflect.TypeOf(*new(time.Time)) case TypeBlob: @@ -59,6 +62,8 @@ return reflect.TypeOf(make(map[string]interface{})) case TypeDate: return reflect.TypeOf(*new(time.Time)) + case TypeDuration: + return reflect.TypeOf(*new(Duration)) default: return nil } @@ -68,7 +73,7 @@ return reflect.Indirect(reflect.ValueOf(i)).Interface() } -func getCassandraType(name string) Type { +func getCassandraBaseType(name string) Type { switch name { case "ascii": return TypeAscii @@ -88,12 +93,18 @@ return TypeFloat case "int": return TypeInt + case "tinyint": + return TypeTinyInt + case "time": + return TypeTime case "timestamp": return TypeTimestamp case "uuid": return TypeUUID - case "varchar", "text": + case "varchar": return TypeVarchar + case "text": + return TypeText case "varint": return TypeVarint case "timeuuid": @@ -109,19 +120,97 @@ case "TupleType": return TypeTuple default: - if strings.HasPrefix(name, "set") { - return TypeSet - } else if strings.HasPrefix(name, "list") { - return TypeList - } else if strings.HasPrefix(name, "map") { - return TypeMap - } else if strings.HasPrefix(name, "tuple") { - return TypeTuple - } return TypeCustom } } +func getCassandraType(name string) TypeInfo { + if strings.HasPrefix(name, "frozen<") { + return getCassandraType(strings.TrimPrefix(name[:len(name)-1], "frozen<")) + } else if strings.HasPrefix(name, "set<") { + return CollectionType{ + NativeType: NativeType{typ: TypeSet}, + Elem: getCassandraType(strings.TrimPrefix(name[:len(name)-1], "set<")), + } + } else if strings.HasPrefix(name, "list<") { + return CollectionType{ + NativeType: NativeType{typ: TypeList}, + Elem: getCassandraType(strings.TrimPrefix(name[:len(name)-1], "list<")), + } + } else if strings.HasPrefix(name, "map<") { + names := 
splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "map<")) + if len(names) != 2 { + Logger.Printf("Error parsing map type, it has %d subelements, expecting 2\n", len(names)) + return NativeType{ + typ: TypeCustom, + } + } + return CollectionType{ + NativeType: NativeType{typ: TypeMap}, + Key: getCassandraType(names[0]), + Elem: getCassandraType(names[1]), + } + } else if strings.HasPrefix(name, "tuple<") { + names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "tuple<")) + types := make([]TypeInfo, len(names)) + + for i, name := range names { + types[i] = getCassandraType(name) + } + + return TupleTypeInfo{ + NativeType: NativeType{typ: TypeTuple}, + Elems: types, + } + } else { + return NativeType{ + typ: getCassandraBaseType(name), + } + } +} + +func splitCompositeTypes(name string) []string { + if !strings.Contains(name, "<") { + return strings.Split(name, ", ") + } + var parts []string + lessCount := 0 + segment := "" + for _, char := range name { + if char == ',' && lessCount == 0 { + if segment != "" { + parts = append(parts, strings.TrimSpace(segment)) + } + segment = "" + continue + } + segment += string(char) + if char == '<' { + lessCount++ + } else if char == '>' { + lessCount-- + } + } + if segment != "" { + parts = append(parts, strings.TrimSpace(segment)) + } + return parts +} + +func apacheToCassandraType(t string) string { + t = strings.Replace(t, apacheCassandraTypePrefix, "", -1) + t = strings.Replace(t, "(", "<", -1) + t = strings.Replace(t, ")", ">", -1) + types := strings.FieldsFunc(t, func(r rune) bool { + return r == '<' || r == '>' || r == ',' + }) + for _, typ := range types { + t = strings.Replace(t, typ, getApacheCassandraType(typ).String(), -1) + } + // This is done so it exactly matches what Cassandra returns + return strings.Replace(t, ",", ", ", -1) +} + func getApacheCassandraType(class string) Type { switch strings.TrimPrefix(class, apacheCassandraTypePrefix) { case "AsciiType": @@ -146,6 +235,8 @@ return TypeSmallInt case "ByteType": return TypeTinyInt + case "TimeType": + return TypeTime case "DateType", "TimestampType": return TypeTimestamp case "UUIDType", "LexicalUUIDType": @@ -166,6 +257,8 @@ return TypeSet case "TupleType": return TypeTuple + case "DurationType": + return TypeDuration default: return TypeCustom } @@ -205,30 +298,43 @@ return RowData{}, iter.err } - columns := make([]string, 0) - values := make([]interface{}, 0) + columns := make([]string, 0, len(iter.Columns())) + values := make([]interface{}, 0, len(iter.Columns())) for _, column := range iter.Columns() { - - switch c := column.TypeInfo.(type) { - case TupleTypeInfo: + if c, ok := column.TypeInfo.(TupleTypeInfo); !ok { + val := column.TypeInfo.New() + columns = append(columns, column.Name) + values = append(values, val) + } else { for i, elem := range c.Elems { columns = append(columns, TupleColumnName(column.Name, i)) values = append(values, elem.New()) } - default: - val := column.TypeInfo.New() - columns = append(columns, column.Name) - values = append(values, val) } } + rowData := RowData{ Columns: columns, Values: values, } + return rowData, nil } +// TODO(zariel): is it worth exporting this? +func (iter *Iter) rowMap() (map[string]interface{}, error) { + if iter.err != nil { + return nil, iter.err + } + + rowData, _ := iter.RowData() + iter.Scan(rowData.Values...) 
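The parser above accepts CQL-style composite type names, including nested frozen forms. A sketch of typical inputs as used from inside the package; the exact spacing inside the nested strings is an assumption:

    getCassandraType("set<text>")             // CollectionType, Elem text
    getCassandraType("map<text, varchar>")    // CollectionType with Key and Elem
    getCassandraType("list<int>")             // CollectionType, Elem int
    getCassandraType("tuple<int, int, text>") // TupleTypeInfo with three Elems
    getCassandraType("frozen<map<text, frozen<list<frozen<tuple<int, int>>>>>>") // nested frozen collections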
+ m := make(map[string]interface{}, len(rowData.Columns)) + rowData.rowMap(m) + return m, nil +} + // SliceMap is a helper function to make the API easier to use // returns the data from the query in the form of []map[string]interface{} func (iter *Iter) SliceMap() ([]map[string]interface{}, error) { @@ -240,7 +346,7 @@ rowData, _ := iter.RowData() dataToReturn := make([]map[string]interface{}, 0) for iter.Scan(rowData.Values...) { - m := make(map[string]interface{}) + m := make(map[string]interface{}, len(rowData.Columns)) rowData.rowMap(m) dataToReturn = append(dataToReturn, m) } @@ -314,3 +420,13 @@ copy(b, p) return b } + +var failDNS = false + +func LookupIP(host string) ([]net.IP, error) { + if failDNS { + return nil, &net.DNSError{} + } + return net.LookupIP(host) + +} diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/helpers_test.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/helpers_test.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/helpers_test.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/helpers_test.go 2019-11-02 13:15:23.000000000 +0000 @@ -0,0 +1,186 @@ +package gocql + +import ( + "reflect" + "testing" +) + +func TestGetCassandraType_Set(t *testing.T) { + typ := getCassandraType("set") + set, ok := typ.(CollectionType) + if !ok { + t.Fatalf("expected CollectionType got %T", typ) + } else if set.typ != TypeSet { + t.Fatalf("expected type %v got %v", TypeSet, set.typ) + } + + inner, ok := set.Elem.(NativeType) + if !ok { + t.Fatalf("expected to get NativeType got %T", set.Elem) + } else if inner.typ != TypeText { + t.Fatalf("expected to get %v got %v for set value", TypeText, set.typ) + } +} + +func TestGetCassandraType(t *testing.T) { + tests := []struct { + input string + exp TypeInfo + }{ + { + "set", CollectionType{ + NativeType: NativeType{typ: TypeSet}, + + Elem: NativeType{typ: TypeText}, + }, + }, + { + "map", CollectionType{ + NativeType: NativeType{typ: TypeMap}, + + Key: NativeType{typ: TypeText}, + Elem: NativeType{typ: TypeVarchar}, + }, + }, + { + "list", CollectionType{ + NativeType: NativeType{typ: TypeList}, + Elem: NativeType{typ: TypeInt}, + }, + }, + { + "tuple", TupleTypeInfo{ + NativeType: NativeType{typ: TypeTuple}, + + Elems: []TypeInfo{ + NativeType{typ: TypeInt}, + NativeType{typ: TypeInt}, + NativeType{typ: TypeText}, + }, + }, + }, + { + "frozen>>>>>", CollectionType{ + NativeType: NativeType{typ: TypeMap}, + + Key: NativeType{typ: TypeText}, + Elem: CollectionType{ + NativeType: NativeType{typ: TypeList}, + Elem: TupleTypeInfo{ + NativeType: NativeType{typ: TypeTuple}, + + Elems: []TypeInfo{ + NativeType{typ: TypeInt}, + NativeType{typ: TypeInt}, + }, + }, + }, + }, + }, + { + "frozen>>>>>, frozen>>>>>, frozen>>>>>>>", + TupleTypeInfo{ + NativeType: NativeType{typ: TypeTuple}, + Elems: []TypeInfo{ + TupleTypeInfo{ + NativeType: NativeType{typ: TypeTuple}, + Elems: []TypeInfo{ + NativeType{typ: TypeText}, + CollectionType{ + NativeType: NativeType{typ: TypeList}, + Elem: TupleTypeInfo{ + NativeType: NativeType{typ: TypeTuple}, + Elems: []TypeInfo{ + NativeType{typ: TypeInt}, + NativeType{typ: TypeInt}, + }, + }, + }, + }, + }, + TupleTypeInfo{ + NativeType: NativeType{typ: TypeTuple}, + Elems: []TypeInfo{ + NativeType{typ: TypeText}, + CollectionType{ + NativeType: NativeType{typ: TypeList}, + Elem: TupleTypeInfo{ + NativeType: NativeType{typ: TypeTuple}, + Elems: []TypeInfo{ + NativeType{typ: TypeInt}, + NativeType{typ: TypeInt}, + }, + }, + }, + }, + }, + 
CollectionType{ + NativeType: NativeType{typ: TypeMap}, + Key: NativeType{typ: TypeText}, + Elem: CollectionType{ + NativeType: NativeType{typ: TypeList}, + Elem: TupleTypeInfo{ + NativeType: NativeType{typ: TypeTuple}, + Elems: []TypeInfo{ + NativeType{typ: TypeInt}, + NativeType{typ: TypeInt}, + }, + }, + }, + }, + }, + }, + }, + { + "frozen>, int, frozen>>>", TupleTypeInfo{ + NativeType: NativeType{typ: TypeTuple}, + + Elems: []TypeInfo{ + TupleTypeInfo{ + NativeType: NativeType{typ: TypeTuple}, + + Elems: []TypeInfo{ + NativeType{typ: TypeInt}, + NativeType{typ: TypeInt}, + }, + }, + NativeType{typ: TypeInt}, + TupleTypeInfo{ + NativeType: NativeType{typ: TypeTuple}, + + Elems: []TypeInfo{ + NativeType{typ: TypeInt}, + NativeType{typ: TypeInt}, + }, + }, + }, + }, + }, + { + "frozen>, int>>", CollectionType{ + NativeType: NativeType{typ: TypeMap}, + + Key: TupleTypeInfo{ + NativeType: NativeType{typ: TypeTuple}, + + Elems: []TypeInfo{ + NativeType{typ: TypeInt}, + NativeType{typ: TypeInt}, + }, + }, + Elem: NativeType{typ: TypeInt}, + }, + }, + } + + for _, test := range tests { + t.Run(test.input, func(t *testing.T) { + got := getCassandraType(test.input) + + // TODO(zariel): define an equal method on the types? + if !reflect.DeepEqual(got, test.exp) { + t.Fatalf("expected %v got %v", test.exp, got) + } + }) + } +} diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/host_source_gen.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/host_source_gen.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/host_source_gen.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/host_source_gen.go 2019-11-02 13:15:23.000000000 +0000 @@ -0,0 +1,45 @@ +// +build genhostinfo + +package main + +import ( + "fmt" + "reflect" + "sync" + + "github.com/gocql/gocql" +) + +func gen(clause, field string) { + fmt.Printf("if h.%s == %s {\n", field, clause) + fmt.Printf("\th.%s = from.%s\n", field, field) + fmt.Println("}") +} + +func main() { + t := reflect.ValueOf(&gocql.HostInfo{}).Elem().Type() + mu := reflect.TypeOf(sync.RWMutex{}) + + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + if f.Type == mu { + continue + } + + switch f.Type.Kind() { + case reflect.Slice: + gen("nil", f.Name) + case reflect.String: + gen(`""`, f.Name) + case reflect.Int: + gen("0", f.Name) + case reflect.Struct: + gen("("+f.Type.Name()+"{})", f.Name) + case reflect.Bool, reflect.Int32: + continue + default: + panic(fmt.Sprintf("unknown field: %s", f)) + } + } + +} diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/host_source.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/host_source.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/host_source.go 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/host_source.go 2019-11-02 13:15:23.000000000 +0000 @@ -1,6 +1,7 @@ package gocql import ( + "context" "errors" "fmt" "net" @@ -73,16 +74,25 @@ } func (c cassVersion) Before(major, minor, patch int) bool { - if c.Major > major { - return true - } else if c.Minor > minor { - return true - } else if c.Patch > patch { + // We're comparing us (cassVersion) with the provided version (major, minor, patch) + // We return true if our version is lower (comes before) than the provided one. 
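A worked example of that comparison; the version numbers are illustrative:

    v := cassVersion{Major: 3, Minor: 11, Patch: 4}

    v.Before(4, 0, 0)  // true:  3.11.4 sorts before 4.0.0
    v.Before(3, 11, 4) // false: an equal version is not before itself
    v.AtLeast(3, 0, 0) // true:  defined below as !Before(3, 0, 0)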
+ if c.Major < major { return true + } else if c.Major == major { + if c.Minor < minor { + return true + } else if c.Minor == minor && c.Patch < patch { + return true + } + } return false } +func (c cassVersion) AtLeast(major, minor, patch int) bool { + return !c.Before(major, minor, patch) +} + func (c cassVersion) String() string { return fmt.Sprintf("v%d.%d.%d", c.Major, c.Minor, c.Patch) } @@ -100,6 +110,7 @@ // TODO(zariel): reduce locking maybe, not all values will change, but to ensure // that we are thread safe use a mutex to access all fields. mu sync.RWMutex + hostname string peer net.IP broadcastAddress net.IP listenAddress net.IP @@ -117,6 +128,7 @@ clusterName string version cassVersion state nodeState + schemaVersion string tokens []string } @@ -183,6 +195,7 @@ } func (h *HostInfo) SetConnectAddress(address net.IP) *HostInfo { + // TODO(zariel): should this not be exported? h.mu.Lock() defer h.mu.Unlock() h.connectAddress = address @@ -215,8 +228,9 @@ func (h *HostInfo) DataCenter() string { h.mu.RLock() - defer h.mu.RUnlock() - return h.dataCenter + dc := h.dataCenter + h.mu.RUnlock() + return dc } func (h *HostInfo) setDataCenter(dataCenter string) *HostInfo { @@ -228,8 +242,9 @@ func (h *HostInfo) Rack() string { h.mu.RLock() - defer h.mu.RUnlock() - return h.rack + rack := h.rack + h.mu.RUnlock() + return rack } func (h *HostInfo) setRack(rack string) *HostInfo { @@ -335,28 +350,87 @@ } func (h *HostInfo) update(from *HostInfo) { + if h == from { + return + } + h.mu.Lock() defer h.mu.Unlock() - h.tokens = from.tokens - h.version = from.version - h.hostId = from.hostId - h.dataCenter = from.dataCenter + from.mu.RLock() + defer from.mu.RUnlock() + + // autogenerated do not update + if h.peer == nil { + h.peer = from.peer + } + if h.broadcastAddress == nil { + h.broadcastAddress = from.broadcastAddress + } + if h.listenAddress == nil { + h.listenAddress = from.listenAddress + } + if h.rpcAddress == nil { + h.rpcAddress = from.rpcAddress + } + if h.preferredIP == nil { + h.preferredIP = from.preferredIP + } + if h.connectAddress == nil { + h.connectAddress = from.connectAddress + } + if h.port == 0 { + h.port = from.port + } + if h.dataCenter == "" { + h.dataCenter = from.dataCenter + } + if h.rack == "" { + h.rack = from.rack + } + if h.hostId == "" { + h.hostId = from.hostId + } + if h.workload == "" { + h.workload = from.workload + } + if h.dseVersion == "" { + h.dseVersion = from.dseVersion + } + if h.partitioner == "" { + h.partitioner = from.partitioner + } + if h.clusterName == "" { + h.clusterName = from.clusterName + } + if h.version == (cassVersion{}) { + h.version = from.version + } + if h.tokens == nil { + h.tokens = from.tokens + } } func (h *HostInfo) IsUp() bool { return h != nil && h.State() == NodeUp } +func (h *HostInfo) HostnameAndPort() string { + if h.hostname == "" { + h.hostname = h.ConnectAddress().String() + } + return net.JoinHostPort(h.hostname, strconv.Itoa(h.port)) +} + func (h *HostInfo) String() string { h.mu.RLock() defer h.mu.RUnlock() connectAddr, source := h.connectAddressLocked() - return fmt.Sprintf("[HostInfo connectAddress=%q peer=%q rpc_address=%q broadcast_address=%q "+ + return fmt.Sprintf("[HostInfo hostname=%q connectAddress=%q peer=%q rpc_address=%q broadcast_address=%q "+ "preferred_ip=%q connect_addr=%q connect_addr_source=%q "+ "port=%d data_centre=%q rack=%q host_id=%q version=%q state=%s num_tokens=%d]", - h.connectAddress, h.peer, h.rpcAddress, h.broadcastAddress, h.preferredIP, + h.hostname, h.connectAddress, h.peer, 
h.rpcAddress, h.broadcastAddress, h.preferredIP, connectAddr, source, h.port, h.dataCenter, h.rack, h.hostId, h.version, h.state, len(h.tokens)) } @@ -366,7 +440,6 @@ session *Session mu sync.Mutex prevHosts []*HostInfo - localHost *HostInfo prevPartitioner string } @@ -388,15 +461,11 @@ // Given a map that represents a row from either system.local or system.peers // return as much information as we can in *HostInfo -func (r *ringDescriber) hostInfoFromMap(row map[string]interface{}) (*HostInfo, error) { +func (s *Session) hostInfoFromMap(row map[string]interface{}, host *HostInfo) (*HostInfo, error) { const assertErrorMsg = "Assertion failed for %s" var ok bool // Default to our connected port if the cluster doesn't have port information - host := HostInfo{ - port: r.session.cfg.Port, - } - for key, value := range row { switch key { case "data_center": @@ -481,91 +550,62 @@ if !ok { return nil, fmt.Errorf(assertErrorMsg, "dse_version") } + case "schema_version": + schemaVersion, ok := value.(UUID) + if !ok { + return nil, fmt.Errorf(assertErrorMsg, "schema_version") + } + host.schemaVersion = schemaVersion.String() } // TODO(thrawn01): Add 'port'? once CASSANDRA-7544 is complete // Not sure what the port field will be called until the JIRA issue is complete } - return &host, nil -} - -// Ask the control node for it's local host information -func (r *ringDescriber) GetLocalHostInfo() (*HostInfo, error) { - it := r.session.control.query("SELECT * FROM system.local WHERE key='local'") - if it == nil { - return nil, errors.New("Attempted to query 'system.local' on a closed control connection") - } - host, err := r.extractHostInfo(it) - if err != nil { - return nil, err - } - - if host.invalidConnectAddr() { - host.SetConnectAddress(r.session.control.GetHostInfo().ConnectAddress()) - } + ip, port := s.cfg.translateAddressPort(host.ConnectAddress(), host.port) + host.connectAddress = ip + host.port = port return host, nil } -// Given an ip address and port, return a peer that matched the ip address -func (r *ringDescriber) GetPeerHostInfo(ip net.IP, port int) (*HostInfo, error) { - it := r.session.control.query("SELECT * FROM system.peers WHERE peer=?", ip) - if it == nil { - return nil, errors.New("Attempted to query 'system.peers' on a closed control connection") - } - return r.extractHostInfo(it) -} - -func (r *ringDescriber) extractHostInfo(it *Iter) (*HostInfo, error) { - row := make(map[string]interface{}) - - // expect only 1 row - it.MapScan(row) - if err := it.Close(); err != nil { - return nil, err - } - - // extract all available info about the host - return r.hostInfoFromMap(row) -} - // Ask the control node for host info on all it's known peers -func (r *ringDescriber) GetClusterPeerInfo() ([]*HostInfo, error) { +func (r *ringDescriber) getClusterPeerInfo() ([]*HostInfo, error) { var hosts []*HostInfo + iter := r.session.control.withConnHost(func(ch *connHost) *Iter { + hosts = append(hosts, ch.host) + return ch.conn.query(context.TODO(), "SELECT * FROM system.peers") + }) - // Ask the node for a list of it's peers - it := r.session.control.query("SELECT * FROM system.peers") - if it == nil { - return nil, errors.New("Attempted to query 'system.peers' on a closed connection") + if iter == nil { + return nil, errNoControl } - for { - row := make(map[string]interface{}) - if !it.MapScan(row) { - break - } + rows, err := iter.SliceMap() + if err != nil { + // TODO(zariel): make typed error + return nil, fmt.Errorf("unable to fetch peer host info: %s", err) + } + + for _, row := range 
rows { // extract all available info about the peer - host, err := r.hostInfoFromMap(row) + host, err := r.session.hostInfoFromMap(row, &HostInfo{port: r.session.cfg.Port}) if err != nil { return nil, err - } - - // If it's not a valid peer - if !r.IsValidPeer(host) { - Logger.Printf("Found invalid peer '%+v' "+ + } else if !isValidPeer(host) { + // If it's not a valid peer + Logger.Printf("Found invalid peer '%s' "+ "Likely due to a gossip or snitch issue, this host will be ignored", host) continue } + hosts = append(hosts, host) } - if it.err != nil { - return nil, fmt.Errorf("while scanning 'system.peers' table: %s", it.err) - } + return hosts, nil } // Return true if the host is a valid peer -func (r *ringDescriber) IsValidPeer(host *HostInfo) bool { +func isValidPeer(host *HostInfo) bool { return !(len(host.RPCAddress()) == 0 || host.hostId == "" || host.dataCenter == "" || @@ -578,84 +618,58 @@ r.mu.Lock() defer r.mu.Unlock() - // Update the localHost info with data from the connected host - localHost, err := r.GetLocalHostInfo() + hosts, err := r.getClusterPeerInfo() if err != nil { return r.prevHosts, r.prevPartitioner, err - } else if localHost.invalidConnectAddr() { - panic(fmt.Sprintf("unable to get localhost connect address: %v", localHost)) } - // Update our list of hosts by querying the cluster - hosts, err := r.GetClusterPeerInfo() - if err != nil { - return r.prevHosts, r.prevPartitioner, err + var partitioner string + if len(hosts) > 0 { + partitioner = hosts[0].Partitioner() } - hosts = append(hosts, localHost) - - // Filter the hosts if filter is provided - filteredHosts := hosts - if r.session.cfg.HostFilter != nil { - filteredHosts = filteredHosts[:0] - for _, host := range hosts { - if r.session.cfg.HostFilter.Accept(host) { - filteredHosts = append(filteredHosts, host) - } - } - } - - r.prevHosts = filteredHosts - r.prevPartitioner = localHost.partitioner - r.localHost = localHost - - return filteredHosts, localHost.partitioner, nil + return hosts, partitioner, nil } // Given an ip/port return HostInfo for the specified ip/port -func (r *ringDescriber) GetHostInfo(ip net.IP, port int) (*HostInfo, error) { - // TODO(thrawn01): Is IgnorePeerAddr still useful now that we have DisableInitialHostLookup? - // TODO(thrawn01): should we also check for DisableInitialHostLookup and return if true? 
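Those TODOs refer to cluster-level switches. A sketch of setting them, with the seed address a placeholder, for deployments where the addresses reported in system.peers are not routable from the client:

    cluster := gocql.NewCluster("203.0.113.10")
    // Connect only to the seed addresses given above rather than the
    // discovered peer addresses.
    cluster.DisableInitialHostLookup = true
    cluster.IgnorePeerAddr = true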
- - // Ignore the port and connect address and use the address/port we already have - if r.session.control == nil || r.session.cfg.IgnorePeerAddr { - return &HostInfo{connectAddress: ip, port: port}, nil - } +func (r *ringDescriber) getHostInfo(ip net.IP, port int) (*HostInfo, error) { + var host *HostInfo + iter := r.session.control.withConnHost(func(ch *connHost) *Iter { + if ch.host.ConnectAddress().Equal(ip) { + host = ch.host + return nil + } - // Attempt to get the host info for our control connection - controlHost := r.session.control.GetHostInfo() - if controlHost == nil { - return nil, errors.New("invalid control connection") - } + return ch.conn.query(context.TODO(), "SELECT * FROM system.peers") + }) - var ( - host *HostInfo - err error - ) + if iter != nil { + rows, err := iter.SliceMap() + if err != nil { + return nil, err + } - // If we are asking about the same node our control connection has a connection too - if controlHost.ConnectAddress().Equal(ip) { - host, err = r.GetLocalHostInfo() - } else { - host, err = r.GetPeerHostInfo(ip, port) - } + for _, row := range rows { + h, err := r.session.hostInfoFromMap(row, &HostInfo{port: port}) + if err != nil { + return nil, err + } - // No host was found matching this ip/port - if err != nil { - return nil, err - } + if h.ConnectAddress().Equal(ip) { + host = h + break + } + } - if controlHost.ConnectAddress().Equal(ip) { - // Always respect the provided control node address and disregard the ip address - // the cassandra node provides. We do this as we are already connected and have a - // known valid ip address. This insulates gocql from client connection issues stemming - // from node misconfiguration. For instance when a node is run from a container, by - // default the node will report its ip address as 127.0.0.1 which is typically invalid. 
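The translateAddressPort calls in this patch go through the configured address translator. A sketch of supplying one, assuming the AddressTranslatorFunc adapter and the usual net import; the address mapping itself is made up:

    cluster.AddressTranslator = gocql.AddressTranslatorFunc(
            func(addr net.IP, port int) (net.IP, int) {
                    // Map the node-reported internal address to one routable from this client.
                    if addr.Equal(net.ParseIP("10.0.0.5")) {
                            return net.ParseIP("203.0.113.10"), port
                    }
                    return addr, port
            })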
- host.SetConnectAddress(ip) + if host == nil { + return nil, errors.New("host not found in peers table") + } } - if host.invalidConnectAddr() { - return nil, fmt.Errorf("host ConnectAddress invalid: %v", host) + if host == nil { + return nil, errors.New("unable to fetch host info: invalid control connection") + } else if host.invalidConnectAddr() { + return nil, fmt.Errorf("host ConnectAddress invalid ip=%v: %v", ip, host) } return host, nil @@ -675,6 +689,10 @@ // TODO: move this to session for _, h := range hosts { + if filter := r.session.cfg.HostFilter; filter != nil && !filter.Accept(h) { + continue + } + if host, ok := r.session.ring.addHostIfMissing(h); !ok { r.session.pool.addHost(h) r.session.policy.AddHost(h) diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/host_source_test.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/host_source_test.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/host_source_test.go 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/host_source_test.go 2019-11-02 13:15:23.000000000 +0000 @@ -1,9 +1,8 @@ -// +build all integration +// +build all cassandra package gocql import ( - "fmt" "net" "testing" ) @@ -39,10 +38,11 @@ {cassVersion{1, 0, 0}, 0, 1, 0}, {cassVersion{0, 1, 0}, 0, 0, 1}, + {cassVersion{4, 1, 0}, 3, 1, 2}, } for i, test := range tests { - if !test.version.Before(test.major, test.minor, test.patch) { + if test.version.Before(test.major, test.minor, test.patch) { t.Errorf("%d: expected v%d.%d.%d to be before %v", i, test.major, test.minor, test.patch, test.version) } } @@ -50,7 +50,6 @@ } func TestIsValidPeer(t *testing.T) { - ring := ringDescriber{} host := &HostInfo{ rpcAddress: net.ParseIP("0.0.0.0"), rack: "myRack", @@ -59,52 +58,16 @@ tokens: []string{"0", "1"}, } - if !ring.IsValidPeer(host) { + if !isValidPeer(host) { t.Errorf("expected %+v to be a valid peer", host) } host.rack = "" - if ring.IsValidPeer(host) { + if isValidPeer(host) { t.Errorf("expected %+v to NOT be a valid peer", host) } } -func TestGetHosts(t *testing.T) { - cluster := createCluster() - session := createSessionFromCluster(cluster, t) - - hosts, partitioner, err := session.hostSource.GetHosts() - - assertTrue(t, "err == nil", err == nil) - assertTrue(t, "len(hosts) == 3", len(hosts) == 3) - assertTrue(t, "len(partitioner) != 0", len(partitioner) != 0) - -} - -func TestGetHostsWithFilter(t *testing.T) { - filterHostIP := net.ParseIP("127.0.0.3") - cluster := createCluster() - - // Filter to remove one of the localhost nodes - cluster.HostFilter = HostFilterFunc(func(host *HostInfo) bool { - if host.ConnectAddress().Equal(filterHostIP) { - return false - } - return true - }) - session := createSessionFromCluster(cluster, t) - - hosts, partitioner, err := session.hostSource.GetHosts() - assertTrue(t, "err == nil", err == nil) - assertTrue(t, "len(hosts) == 2", len(hosts) == 2) - assertTrue(t, "len(partitioner) != 0", len(partitioner) != 0) - for _, host := range hosts { - if host.ConnectAddress().Equal(filterHostIP) { - t.Fatal(fmt.Sprintf("Did not expect to see '%q' in host list", filterHostIP)) - } - } -} - func TestHostInfo_ConnectAddress(t *testing.T) { var localhost = net.IPv4(127, 0, 0, 1) tests := []struct { diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/install_test_deps.sh golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/install_test_deps.sh --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/install_test_deps.sh 1970-01-01 00:00:00.000000000 +0000 +++ 
golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/install_test_deps.sh 2019-11-02 13:15:23.000000000 +0000 @@ -0,0 +1,16 @@ +#!/usr/bin/env bash + +# This is not supposed to be an error-prone script; just a convenience. + +# Install CCM +pip install --user cql PyYAML six +git clone https://github.com/pcmanus/ccm.git +pushd ccm +./setup.py install --user +popd + +if [ "$1" != "gocql/gocql" ]; then + USER=$(echo $1 | cut -f1 -d'/') + cd ../.. + mv ${USER} gocql +fi diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/integration.sh golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/integration.sh --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/integration.sh 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/integration.sh 2019-11-02 13:15:23.000000000 +0000 @@ -50,9 +50,11 @@ elif [[ $version == 2.2.* || $version == 3.0.* ]]; then proto=4 ccm updateconf 'enable_user_defined_functions: true' + export JVM_EXTRA_OPTS=" -Dcassandra.test.fail_writes_ks=test -Dcassandra.custom_query_handler_class=org.apache.cassandra.cql3.CustomPayloadMirroringQueryHandler" elif [[ $version == 3.*.* ]]; then - proto=4 + proto=5 ccm updateconf 'enable_user_defined_functions: true' + export JVM_EXTRA_OPTS=" -Dcassandra.test.fail_writes_ks=test -Dcassandra.custom_query_handler_class=org.apache.cassandra.cql3.CustomPayloadMirroringQueryHandler" fi sleep 1s @@ -64,7 +66,7 @@ local args="-gocql.timeout=60s -runssl -proto=$proto -rf=3 -clusterSize=$clusterSize -autowait=2000ms -compressor=snappy -gocql.cversion=$version -cluster=$(ccm liveset) ./..." - go test -v -tags unit + go test -v -tags unit -race if [ "$auth" = true ] then @@ -72,13 +74,19 @@ go test -run=TestAuthentication -tags "integration gocql_debug" -timeout=15s -runauth $args else sleep 1s - go test -tags "integration gocql_debug" -timeout=5m $args + go test -tags "cassandra gocql_debug" -timeout=5m -race $args + + ccm clear + ccm start --wait-for-binary-proto + sleep 1s + + go test -tags "integration gocql_debug" -timeout=5m -race $args ccm clear - ccm start + ccm start --wait-for-binary-proto sleep 1s - go test -tags "ccm gocql_debug" -timeout=5m $args + go test -tags "ccm gocql_debug" -timeout=5m -race $args fi ccm remove diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/integration_test.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/integration_test.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/integration_test.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/integration_test.go 2019-11-02 13:15:23.000000000 +0000 @@ -0,0 +1,190 @@ +// +build all integration + +package gocql + +// This file groups integration tests where Cassandra has to be set up with some special integration variables +import ( + "reflect" + "testing" + "time" +) + +// TestAuthentication verifies that gocql will work with a host configured to only accept authenticated connections +func TestAuthentication(t *testing.T) { + + if *flagProto < 2 { + t.Skip("Authentication is not supported with protocol < 2") + } + + if !*flagRunAuthTest { + t.Skip("Authentication is not configured in the target cluster") + } + + cluster := createCluster() + + cluster.Authenticator = PasswordAuthenticator{ + Username: "cassandra", + Password: "cassandra", + } + + session, err := cluster.CreateSession() + + if err != nil { + t.Fatalf("Authentication error: %s", err) + } + + session.Close() +} + +func TestGetHosts(t *testing.T) { + clusterHosts := getClusterHosts() + 
cluster := createCluster() + session := createSessionFromCluster(cluster, t) + + hosts, partitioner, err := session.hostSource.GetHosts() + + assertTrue(t, "err == nil", err == nil) + assertEqual(t, "len(hosts)", len(clusterHosts), len(hosts)) + assertTrue(t, "len(partitioner) != 0", len(partitioner) != 0) +} + +//TestRingDiscovery makes sure that you can autodiscover other cluster members when you seed a cluster config with just one node +func TestRingDiscovery(t *testing.T) { + clusterHosts := getClusterHosts() + cluster := createCluster() + cluster.Hosts = clusterHosts[:1] + + session := createSessionFromCluster(cluster, t) + defer session.Close() + + if *clusterSize > 1 { + // wait for autodiscovery to update the pool with the list of known hosts + time.Sleep(*flagAutoWait) + } + + session.pool.mu.RLock() + defer session.pool.mu.RUnlock() + size := len(session.pool.hostConnPools) + + if *clusterSize != size { + for p, pool := range session.pool.hostConnPools { + t.Logf("p=%q host=%v ips=%s", p, pool.host, pool.host.ConnectAddress().String()) + + } + t.Errorf("Expected a cluster size of %d, but actual size was %d", *clusterSize, size) + } +} + +func TestWriteFailure(t *testing.T) { + cluster := createCluster() + createKeyspace(t, cluster, "test") + cluster.Keyspace = "test" + session, err := cluster.CreateSession() + if err != nil { + t.Fatal("create session:", err) + } + defer session.Close() + + if err := createTable(session, "CREATE TABLE test.test (id int,value int,PRIMARY KEY (id))"); err != nil { + t.Fatalf("failed to create table with error '%v'", err) + } + if err := session.Query(`INSERT INTO test.test (id, value) VALUES (1, 1)`).Exec(); err != nil { + errWrite, ok := err.(*RequestErrWriteFailure) + if ok { + if session.cfg.ProtoVersion >= 5 { + // ErrorMap should be filled with some hosts that should've errored + if len(errWrite.ErrorMap) == 0 { + t.Fatal("errWrite.ErrorMap should have some failed hosts but it didn't have any") + } + } else { + // Map doesn't get filled for V4 + if len(errWrite.ErrorMap) != 0 { + t.Fatal("errWrite.ErrorMap should have length 0, it's: ", len(errWrite.ErrorMap)) + } + } + } else { + t.Fatal("error should be RequestErrWriteFailure, it's: ", errWrite) + } + } else { + t.Fatal("a write fail error should have happened when querying test keyspace") + } + + if err = session.Query("DROP KEYSPACE test").Exec(); err != nil { + t.Fatal(err) + } +} + +func TestCustomPayloadMessages(t *testing.T) { + cluster := createCluster() + session := createSessionFromCluster(cluster, t) + defer session.Close() + + if err := createTable(session, "CREATE TABLE gocql_test.testCustomPayloadMessages (id int, value int, PRIMARY KEY (id))"); err != nil { + t.Fatal(err) + } + + // QueryMessage + var customPayload = map[string][]byte{"a": []byte{10, 20}, "b": []byte{20, 30}} + query := session.Query("SELECT id FROM testCustomPayloadMessages where id = ?", 42).Consistency(One).CustomPayload(customPayload) + iter := query.Iter() + rCustomPayload := iter.GetCustomPayload() + if !reflect.DeepEqual(customPayload, rCustomPayload) { + t.Fatal("The received custom payload should match the sent") + } + iter.Close() + + // Insert query + query = session.Query("INSERT INTO testCustomPayloadMessages(id,value) VALUES(1, 1)").Consistency(One).CustomPayload(customPayload) + iter = query.Iter() + rCustomPayload = iter.GetCustomPayload() + if !reflect.DeepEqual(customPayload, rCustomPayload) { + t.Fatal("The received custom payload should match the sent") + } + iter.Close() + + // Batch 
Message + b := session.NewBatch(LoggedBatch) + b.CustomPayload = customPayload + b.Query("INSERT INTO testCustomPayloadMessages(id,value) VALUES(1, 1)") + if err := session.ExecuteBatch(b); err != nil { + t.Fatalf("query failed. %v", err) + } +} + +func TestCustomPayloadValues(t *testing.T) { + cluster := createCluster() + session := createSessionFromCluster(cluster, t) + defer session.Close() + + if err := createTable(session, "CREATE TABLE gocql_test.testCustomPayloadValues (id int, value int, PRIMARY KEY (id))"); err != nil { + t.Fatal(err) + } + + values := []map[string][]byte{map[string][]byte{"a": []byte{10, 20}, "b": []byte{20, 30}}, nil, map[string][]byte{"a": []byte{10, 20}, "b": nil}} + + for _, customPayload := range values { + query := session.Query("SELECT id FROM testCustomPayloadValues where id = ?", 42).Consistency(One).CustomPayload(customPayload) + iter := query.Iter() + rCustomPayload := iter.GetCustomPayload() + if !reflect.DeepEqual(customPayload, rCustomPayload) { + t.Fatal("The received custom payload should match the sent") + } + } +} + +func TestUDF(t *testing.T) { + session := createSession(t) + defer session.Close() + if session.cfg.ProtoVersion < 4 { + t.Skip("skipping UDF support on proto < 4") + } + + const query = `CREATE OR REPLACE FUNCTION uniq(state set, val text) + CALLED ON NULL INPUT RETURNS set LANGUAGE java + AS 'state.add(val); return state;'` + + err := session.Query(query).Exec() + if err != nil { + t.Fatal(err) + } +} diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/internal/murmur/murmur_appengine.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/internal/murmur/murmur_appengine.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/internal/murmur/murmur_appengine.go 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/internal/murmur/murmur_appengine.go 2019-11-02 13:15:23.000000000 +0000 @@ -4,134 +4,8 @@ import "encoding/binary" -func Murmur3H1(data []byte) uint64 { - length := len(data) - - var h1, h2, k1, k2 uint64 - - const ( - c1 = 0x87c37b91114253d5 - c2 = 0x4cf5ad432745937f - ) - - // body - nBlocks := length / 16 - for i := 0; i < nBlocks; i++ { - // block := (*[2]uint64)(unsafe.Pointer(&data[i*16])) - - k1 = binary.LittleEndian.Uint64(data[i*16:]) - k2 = binary.LittleEndian.Uint64(data[(i*16)+8:]) - - k1 *= c1 - k1 = (k1 << 31) | (k1 >> 33) // ROTL64(k1, 31) - k1 *= c2 - h1 ^= k1 - - h1 = (h1 << 27) | (h1 >> 37) // ROTL64(h1, 27) - h1 += h2 - h1 = h1*5 + 0x52dce729 - - k2 *= c2 - k2 = (k2 << 33) | (k2 >> 31) // ROTL64(k2, 33) - k2 *= c1 - h2 ^= k2 - - h2 = (h2 << 31) | (h2 >> 33) // ROTL64(h2, 31) - h2 += h1 - h2 = h2*5 + 0x38495ab5 - } - - // tail - tail := data[nBlocks*16:] - k1 = 0 - k2 = 0 - switch length & 15 { - case 15: - k2 ^= uint64(tail[14]) << 48 - fallthrough - case 14: - k2 ^= uint64(tail[13]) << 40 - fallthrough - case 13: - k2 ^= uint64(tail[12]) << 32 - fallthrough - case 12: - k2 ^= uint64(tail[11]) << 24 - fallthrough - case 11: - k2 ^= uint64(tail[10]) << 16 - fallthrough - case 10: - k2 ^= uint64(tail[9]) << 8 - fallthrough - case 9: - k2 ^= uint64(tail[8]) - - k2 *= c2 - k2 = (k2 << 33) | (k2 >> 31) // ROTL64(k2, 33) - k2 *= c1 - h2 ^= k2 - - fallthrough - case 8: - k1 ^= uint64(tail[7]) << 56 - fallthrough - case 7: - k1 ^= uint64(tail[6]) << 48 - fallthrough - case 6: - k1 ^= uint64(tail[5]) << 40 - fallthrough - case 5: - k1 ^= uint64(tail[4]) << 32 - fallthrough - case 4: - k1 ^= uint64(tail[3]) << 24 - fallthrough - case 3: - k1 ^= uint64(tail[2]) 
<< 16 - fallthrough - case 2: - k1 ^= uint64(tail[1]) << 8 - fallthrough - case 1: - k1 ^= uint64(tail[0]) - - k1 *= c1 - k1 = (k1 << 31) | (k1 >> 33) // ROTL64(k1, 31) - k1 *= c2 - h1 ^= k1 - } - - h1 ^= uint64(length) - h2 ^= uint64(length) - - h1 += h2 - h2 += h1 - - // finalizer - const ( - fmix1 = 0xff51afd7ed558ccd - fmix2 = 0xc4ceb9fe1a85ec53 - ) - - // fmix64(h1) - h1 ^= h1 >> 33 - h1 *= fmix1 - h1 ^= h1 >> 33 - h1 *= fmix2 - h1 ^= h1 >> 33 - - // fmix64(h2) - h2 ^= h2 >> 33 - h2 *= fmix1 - h2 ^= h2 >> 33 - h2 *= fmix2 - h2 ^= h2 >> 33 - - h1 += h2 - // the following is extraneous since h2 is discarded - // h2 += h1 - - return h1 +func getBlock(data []byte, n int) (int64, int64) { + k1 := binary.LittleEndian.Int64(data[n*16:]) + k2 := binary.LittleEndian.Int64(data[(n*16)+8:]) + return k1, k2 } diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/internal/murmur/murmur.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/internal/murmur/murmur.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/internal/murmur/murmur.go 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/internal/murmur/murmur.go 2019-11-02 13:15:23.000000000 +0000 @@ -1,44 +1,57 @@ -// +build !appengine - package murmur -import ( - "unsafe" +const ( + c1 int64 = -8663945395140668459 // 0x87c37b91114253d5 + c2 int64 = 5545529020109919103 // 0x4cf5ad432745937f + fmix1 int64 = -49064778989728563 // 0xff51afd7ed558ccd + fmix2 int64 = -4265267296055464877 // 0xc4ceb9fe1a85ec53 ) -func Murmur3H1(data []byte) uint64 { - length := len(data) +func fmix(n int64) int64 { + // cast to unsigned for logical right bitshift (to match C* MM3 implementation) + n ^= int64(uint64(n) >> 33) + n *= fmix1 + n ^= int64(uint64(n) >> 33) + n *= fmix2 + n ^= int64(uint64(n) >> 33) + + return n +} + +func block(p byte) int64 { + return int64(int8(p)) +} + +func rotl(x int64, r uint8) int64 { + // cast to unsigned for logical right bitshift (to match C* MM3 implementation) + return (x << r) | (int64)((uint64(x) >> (64 - r))) +} - var h1, h2, k1, k2 uint64 +func Murmur3H1(data []byte) int64 { + length := len(data) - const ( - c1 = 0x87c37b91114253d5 - c2 = 0x4cf5ad432745937f - ) + var h1, h2, k1, k2 int64 // body nBlocks := length / 16 for i := 0; i < nBlocks; i++ { - block := (*[2]uint64)(unsafe.Pointer(&data[i*16])) - - k1 = block[0] - k2 = block[1] + k1, k2 = getBlock(data, i) k1 *= c1 - k1 = (k1 << 31) | (k1 >> 33) // ROTL64(k1, 31) + k1 = rotl(k1, 31) k1 *= c2 h1 ^= k1 - h1 = (h1 << 27) | (h1 >> 37) // ROTL64(h1, 27) + h1 = rotl(h1, 27) h1 += h2 h1 = h1*5 + 0x52dce729 k2 *= c2 - k2 = (k2 << 33) | (k2 >> 31) // ROTL64(k2, 33) + k2 = rotl(k2, 33) k2 *= c1 h2 ^= k2 - h2 = (h2 << 31) | (h2 >> 33) // ROTL64(h2, 31) + h2 = rotl(h2, 31) h2 += h1 h2 = h2*5 + 0x38495ab5 } @@ -49,87 +62,70 @@ k2 = 0 switch length & 15 { case 15: - k2 ^= uint64(tail[14]) << 48 + k2 ^= block(tail[14]) << 48 fallthrough case 14: - k2 ^= uint64(tail[13]) << 40 + k2 ^= block(tail[13]) << 40 fallthrough case 13: - k2 ^= uint64(tail[12]) << 32 + k2 ^= block(tail[12]) << 32 fallthrough case 12: - k2 ^= uint64(tail[11]) << 24 + k2 ^= block(tail[11]) << 24 fallthrough case 11: - k2 ^= uint64(tail[10]) << 16 + k2 ^= block(tail[10]) << 16 fallthrough case 10: - k2 ^= uint64(tail[9]) << 8 + k2 ^= block(tail[9]) << 8 fallthrough case 9: - k2 ^= uint64(tail[8]) + k2 ^= block(tail[8]) k2 *= c2 - k2 = (k2 << 33) | (k2 >> 31) // ROTL64(k2, 33) + k2 = rotl(k2, 33) k2 *= c1 h2 ^= k2 fallthrough case 8: - k1 ^= uint64(tail[7]) 
<< 56 + k1 ^= block(tail[7]) << 56 fallthrough case 7: - k1 ^= uint64(tail[6]) << 48 + k1 ^= block(tail[6]) << 48 fallthrough case 6: - k1 ^= uint64(tail[5]) << 40 + k1 ^= block(tail[5]) << 40 fallthrough case 5: - k1 ^= uint64(tail[4]) << 32 + k1 ^= block(tail[4]) << 32 fallthrough case 4: - k1 ^= uint64(tail[3]) << 24 + k1 ^= block(tail[3]) << 24 fallthrough case 3: - k1 ^= uint64(tail[2]) << 16 + k1 ^= block(tail[2]) << 16 fallthrough case 2: - k1 ^= uint64(tail[1]) << 8 + k1 ^= block(tail[1]) << 8 fallthrough case 1: - k1 ^= uint64(tail[0]) + k1 ^= block(tail[0]) k1 *= c1 - k1 = (k1 << 31) | (k1 >> 33) // ROTL64(k1, 31) + k1 = rotl(k1, 31) k1 *= c2 h1 ^= k1 } - h1 ^= uint64(length) - h2 ^= uint64(length) + h1 ^= int64(length) + h2 ^= int64(length) h1 += h2 h2 += h1 - // finalizer - const ( - fmix1 = 0xff51afd7ed558ccd - fmix2 = 0xc4ceb9fe1a85ec53 - ) - - // fmix64(h1) - h1 ^= h1 >> 33 - h1 *= fmix1 - h1 ^= h1 >> 33 - h1 *= fmix2 - h1 ^= h1 >> 33 - - // fmix64(h2) - h2 ^= h2 >> 33 - h2 *= fmix1 - h2 ^= h2 >> 33 - h2 *= fmix2 - h2 ^= h2 >> 33 + h1 = fmix(h1) + h2 = fmix(h2) h1 += h2 // the following is extraneous since h2 is discarded diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/internal/murmur/murmur_test.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/internal/murmur/murmur_test.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/internal/murmur/murmur_test.go 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/internal/murmur/murmur_test.go 2019-11-02 13:15:23.000000000 +0000 @@ -1,10 +1,71 @@ package murmur import ( + "encoding/hex" + "fmt" "strconv" "testing" ) +func TestRotl(t *testing.T) { + tests := []struct { + in, rotate, exp int64 + }{ + {123456789, 33, 1060485742448345088}, + {-123456789, 33, -1060485733858410497}, + {-12345678987654, 33, 1756681988166642059}, + + {7210216203459776512, 31, -4287945813905642825}, + {2453826951392495049, 27, -2013042863942636044}, + {270400184080946339, 33, -3553153987756601583}, + {2060965185473694757, 31, 6290866853133484661}, + {3075794793055692309, 33, -3158909918919076318}, + {-6486402271863858009, 31, 405973038345868736}, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("%d >> %d", test.in, test.rotate), func(t *testing.T) { + if v := rotl(test.in, uint8(test.rotate)); v != test.exp { + t.Fatalf("expected %d got %d", test.exp, v) + } + }) + } +} + +func TestFmix(t *testing.T) { + tests := []struct { + in, exp int64 + }{ + {123456789, -8107560010088384378}, + {-123456789, -5252787026298255965}, + {-12345678987654, -1122383578793231303}, + {-1241537367799374202, 3388197556095096266}, + {-7566534940689533355, 4729783097411765989}, + } + + for _, test := range tests { + t.Run(strconv.Itoa(int(test.in)), func(t *testing.T) { + if v := fmix(test.in); v != test.exp { + t.Fatalf("expected %d got %d", test.exp, v) + } + }) + } + +} + +func TestMurmur3H1_CassandraSign(t *testing.T) { + key, err := hex.DecodeString("00104327529fb645dd00b883ec39ae448bb800000400066a6b00") + if err != nil { + t.Fatal(err) + } + h := Murmur3H1(key) + const exp int64 = -9223371632693506265 + + if h != exp { + t.Fatalf("expected %d got %d", exp, h) + } +} + // Test the implementation of murmur3 func TestMurmur3H1(t *testing.T) { // these examples are based on adding a index number to a sample string in @@ -50,8 +111,8 @@ // helper function for testing the murmur3 implementation func assertMurmur3H1(t *testing.T, data []byte, expected uint64) { actual := Murmur3H1(data) - if actual != expected { 
- t.Errorf("Expected h1 = %x for data = %x, but was %x", expected, data, actual) + if actual != int64(expected) { + t.Errorf("Expected h1 = %x for data = %x, but was %x", int64(expected), data, actual) } } @@ -66,8 +127,8 @@ b.RunParallel(func(pb *testing.PB) { for pb.Next() { h1 := Murmur3H1(data) - if h1 != 7627370222079200297 { - b.Fatalf("expected %d got %d", 7627370222079200297, h1) + if h1 != int64(7627370222079200297) { + b.Fatalf("expected %d got %d", int64(7627370222079200297), h1) } } }) diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/internal/murmur/murmur_unsafe.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/internal/murmur/murmur_unsafe.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/internal/murmur/murmur_unsafe.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/internal/murmur/murmur_unsafe.go 2019-11-02 13:15:23.000000000 +0000 @@ -0,0 +1,15 @@ +// +build !appengine + +package murmur + +import ( + "unsafe" +) + +func getBlock(data []byte, n int) (int64, int64) { + block := (*[2]int64)(unsafe.Pointer(&data[n*16])) + + k1 := block[0] + k2 := block[1] + return k1, k2 +} diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/internal/streams/streams.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/internal/streams/streams.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/internal/streams/streams.go 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/internal/streams/streams.go 2019-11-02 13:15:23.000000000 +0000 @@ -105,7 +105,7 @@ buf = append(buf, bitfmt(bits)...) buf = append(buf, ' ') } - return string(buf[:size-1 : size-1]) + return string(buf[: size-1 : size-1]) } func (s *IDGenerator) Clear(stream int) (inuse bool) { diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/marshal.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/marshal.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/marshal.go 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/marshal.go 2019-11-02 13:15:23.000000000 +0000 @@ -11,6 +11,7 @@ "fmt" "math" "math/big" + "math/bits" "net" "reflect" "strconv" @@ -81,7 +82,9 @@ return marshalDouble(info, value) case TypeDecimal: return marshalDecimal(info, value) - case TypeTimestamp, TypeTime: + case TypeTime: + return marshalTime(info, value) + case TypeTimestamp: return marshalTimestamp(info, value) case TypeList, TypeSet: return marshalList(info, value) @@ -99,6 +102,8 @@ return marshalUDT(info, value) case TypeDate: return marshalDate(info, value) + case TypeDuration: + return marshalDuration(info, value) } // detect protocol 2 UDT @@ -143,7 +148,9 @@ return unmarshalDouble(info, data, value) case TypeDecimal: return unmarshalDecimal(info, data, value) - case TypeTimestamp, TypeTime: + case TypeTime: + return unmarshalTime(info, data, value) + case TypeTimestamp: return unmarshalTimestamp(info, data, value) case TypeList, TypeSet: return unmarshalList(info, data, value) @@ -161,6 +168,8 @@ return unmarshalUDT(info, data, value) case TypeDate: return unmarshalDate(info, data, value) + case TypeDuration: + return unmarshalDuration(info, data, value) } // detect protocol 2 UDT @@ -232,12 +241,13 @@ return nil case *[]byte: if data != nil { - *v = copyBytes(data) + *v = append((*v)[:0], data...) 
} else { *v = nil } return nil } + rv := reflect.ValueOf(value) if rv.Kind() != reflect.Ptr { return unmarshalErrorf("can not unmarshal into non-pointer %T", value) @@ -330,7 +340,7 @@ return nil, marshalErrorf("marshal smallint: value %d out of range", v) } return encShort(int16(v)), nil - default: + case reflect.Ptr: if rv.IsNil() { return nil, nil } @@ -414,7 +424,7 @@ return nil, marshalErrorf("marshal tinyint: value %d out of range", v) } return []byte{byte(v)}, nil - default: + case reflect.Ptr: if rv.IsNil() { return nil, nil } @@ -486,7 +496,7 @@ return nil, marshalErrorf("marshal int: value %d out of range", v) } return encInt(int32(v)), nil - default: + case reflect.Ptr: if rv.IsNil() { return nil, nil } @@ -703,9 +713,6 @@ return nil case *uint: unitVal := uint64(int64Val) - if ^uint(0) == math.MaxUint32 && unitVal > math.MaxUint32 { - return unmarshalErrorf("unmarshal int: value %d out of range for %T", unitVal, *v) - } switch info.Type() { case TypeInt: *v = uint(unitVal) & 0xFFFFFFFF @@ -714,6 +721,9 @@ case TypeTinyInt: *v = uint(unitVal) & 0xFF default: + if ^uint(0) == math.MaxUint32 && (int64Val < 0 || int64Val > math.MaxUint32) { + return unmarshalErrorf("unmarshal int: value %d out of range for %T", unitVal, *v) + } *v = uint(unitVal) } return nil @@ -739,15 +749,17 @@ *v = int32(int64Val) return nil case *uint32: - if int64Val > math.MaxUint32 { - return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v) - } switch info.Type() { + case TypeInt: + *v = uint32(int64Val) & 0xFFFFFFFF case TypeSmallInt: *v = uint32(int64Val) & 0xFFFF case TypeTinyInt: *v = uint32(int64Val) & 0xFF default: + if int64Val < 0 || int64Val > math.MaxUint32 { + return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v) + } *v = uint32(int64Val) & 0xFFFFFFFF } return nil @@ -758,13 +770,15 @@ *v = int16(int64Val) return nil case *uint16: - if int64Val > math.MaxUint16 { - return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v) - } switch info.Type() { + case TypeSmallInt: + *v = uint16(int64Val) & 0xFFFF case TypeTinyInt: *v = uint16(int64Val) & 0xFF default: + if int64Val < 0 || int64Val > math.MaxUint16 { + return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v) + } *v = uint16(int64Val) & 0xFFFF } return nil @@ -775,7 +789,7 @@ *v = int8(int64Val) return nil case *uint8: - if int64Val > math.MaxUint8 { + if info.Type() != TypeTinyInt && (int64Val < 0 || int64Val > math.MaxUint8) { return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v) } *v = uint8(int64Val) & 0xFF @@ -823,34 +837,69 @@ rv.SetInt(int64Val) return nil case reflect.Uint: - if int64Val < 0 || (^uint(0) == math.MaxUint32 && int64Val > math.MaxUint32) { - return unmarshalErrorf("unmarshal int: value %d out of range", int64Val) + unitVal := uint64(int64Val) + switch info.Type() { + case TypeInt: + rv.SetUint(unitVal & 0xFFFFFFFF) + case TypeSmallInt: + rv.SetUint(unitVal & 0xFFFF) + case TypeTinyInt: + rv.SetUint(unitVal & 0xFF) + default: + if ^uint(0) == math.MaxUint32 && (int64Val < 0 || int64Val > math.MaxUint32) { + return unmarshalErrorf("unmarshal int: value %d out of range for %s", unitVal, rv.Type()) + } + rv.SetUint(unitVal) } - rv.SetUint(uint64(int64Val)) return nil case reflect.Uint64: - if int64Val < 0 { - return unmarshalErrorf("unmarshal int: value %d out of range", int64Val) + unitVal := uint64(int64Val) + switch info.Type() { + case TypeInt: + rv.SetUint(unitVal & 0xFFFFFFFF) + case 
TypeSmallInt: + rv.SetUint(unitVal & 0xFFFF) + case TypeTinyInt: + rv.SetUint(unitVal & 0xFF) + default: + rv.SetUint(unitVal) } - rv.SetUint(uint64(int64Val)) return nil case reflect.Uint32: - if int64Val < 0 || int64Val > math.MaxUint32 { - return unmarshalErrorf("unmarshal int: value %d out of range", int64Val) + unitVal := uint64(int64Val) + switch info.Type() { + case TypeInt: + rv.SetUint(unitVal & 0xFFFFFFFF) + case TypeSmallInt: + rv.SetUint(unitVal & 0xFFFF) + case TypeTinyInt: + rv.SetUint(unitVal & 0xFF) + default: + if int64Val < 0 || int64Val > math.MaxUint32 { + return unmarshalErrorf("unmarshal int: value %d out of range for %s", int64Val, rv.Type()) + } + rv.SetUint(unitVal & 0xFFFFFFFF) } - rv.SetUint(uint64(int64Val)) return nil case reflect.Uint16: - if int64Val < 0 || int64Val > math.MaxUint16 { - return unmarshalErrorf("unmarshal int: value %d out of range", int64Val) + unitVal := uint64(int64Val) + switch info.Type() { + case TypeSmallInt: + rv.SetUint(unitVal & 0xFFFF) + case TypeTinyInt: + rv.SetUint(unitVal & 0xFF) + default: + if int64Val < 0 || int64Val > math.MaxUint16 { + return unmarshalErrorf("unmarshal int: value %d out of range for %s", int64Val, rv.Type()) + } + rv.SetUint(unitVal & 0xFFFF) } - rv.SetUint(uint64(int64Val)) return nil case reflect.Uint8: - if int64Val < 0 || int64Val > math.MaxUint8 { - return unmarshalErrorf("unmarshal int: value %d out of range", int64Val) + if info.Type() != TypeTinyInt && (int64Val < 0 || int64Val > math.MaxUint8) { + return unmarshalErrorf("unmarshal int: value %d out of range for %s", int64Val, rv.Type()) } - rv.SetUint(uint64(int64Val)) + rv.SetUint(uint64(int64Val) & 0xff) return nil } return unmarshalErrorf("can not unmarshal %s into %T", info, value) @@ -1084,6 +1133,30 @@ return nil } +func marshalTime(info TypeInfo, value interface{}) ([]byte, error) { + switch v := value.(type) { + case Marshaler: + return v.MarshalCQL(info) + case unsetColumn: + return nil, nil + case int64: + return encBigInt(v), nil + case time.Duration: + return encBigInt(v.Nanoseconds()), nil + } + + if value == nil { + return nil, nil + } + + rv := reflect.ValueOf(value) + switch rv.Type().Kind() { + case reflect.Int64: + return encBigInt(rv.Int()), nil + } + return nil, marshalErrorf("can not marshal %T into %s", value, info) +} + func marshalTimestamp(info TypeInfo, value interface{}) ([]byte, error) { switch v := value.(type) { case Marshaler: @@ -1098,8 +1171,6 @@ } x := int64(v.UTC().Unix()*1e3) + int64(v.UTC().Nanosecond()/1e6) return encBigInt(x), nil - case time.Duration: - return encBigInt(v.Nanoseconds()), nil } if value == nil { @@ -1114,6 +1185,31 @@ return nil, marshalErrorf("can not marshal %T into %s", value, info) } +func unmarshalTime(info TypeInfo, data []byte, value interface{}) error { + switch v := value.(type) { + case Unmarshaler: + return v.UnmarshalCQL(info, data) + case *int64: + *v = decBigInt(data) + return nil + case *time.Duration: + *v = time.Duration(decBigInt(data)) + return nil + } + + rv := reflect.ValueOf(value) + if rv.Kind() != reflect.Ptr { + return unmarshalErrorf("can not unmarshal into non-pointer %T", value) + } + rv = rv.Elem() + switch rv.Type().Kind() { + case reflect.Int64: + rv.SetInt(decBigInt(data)) + return nil + } + return unmarshalErrorf("can not unmarshal %s into %T", info, value) +} + func unmarshalTimestamp(info TypeInfo, data []byte, value interface{}) error { switch v := value.(type) { case Unmarshaler: @@ -1131,8 +1227,6 @@ nsec := (x - sec*1000) * 1000000 *v = time.Unix(sec, 
nsec).In(time.UTC) return nil - case *time.Duration: - *v = time.Duration(decBigInt(data)) } rv := reflect.ValueOf(value) @@ -1206,10 +1300,129 @@ timestamp := (int64(current) - int64(origin)) * 86400000 *v = time.Unix(0, timestamp*int64(time.Millisecond)).In(time.UTC) return nil + case *string: + if len(data) == 0 { + *v = "" + return nil + } + var origin uint32 = 1 << 31 + var current uint32 = binary.BigEndian.Uint32(data) + timestamp := (int64(current) - int64(origin)) * 86400000 + *v = time.Unix(0, timestamp*int64(time.Millisecond)).In(time.UTC).Format("2006-01-02") + return nil + } + return unmarshalErrorf("can not unmarshal %s into %T", info, value) +} + +func marshalDuration(info TypeInfo, value interface{}) ([]byte, error) { + switch v := value.(type) { + case Marshaler: + return v.MarshalCQL(info) + case unsetColumn: + return nil, nil + case int64: + return encVints(0, 0, v), nil + case time.Duration: + return encVints(0, 0, v.Nanoseconds()), nil + case string: + d, err := time.ParseDuration(v) + if err != nil { + return nil, err + } + return encVints(0, 0, d.Nanoseconds()), nil + case Duration: + return encVints(v.Months, v.Days, v.Nanoseconds), nil + } + + if value == nil { + return nil, nil + } + + rv := reflect.ValueOf(value) + switch rv.Type().Kind() { + case reflect.Int64: + return encBigInt(rv.Int()), nil + } + return nil, marshalErrorf("can not marshal %T into %s", value, info) +} + +func unmarshalDuration(info TypeInfo, data []byte, value interface{}) error { + switch v := value.(type) { + case Unmarshaler: + return v.UnmarshalCQL(info, data) + case *Duration: + if len(data) == 0 { + *v = Duration{ + Months: 0, + Days: 0, + Nanoseconds: 0, + } + return nil + } + months, days, nanos := decVints(data) + *v = Duration{ + Months: months, + Days: days, + Nanoseconds: nanos, + } + return nil } return unmarshalErrorf("can not unmarshal %s into %T", info, value) } +func decVints(data []byte) (int32, int32, int64) { + month, i := decVint(data) + days, j := decVint(data[i:]) + nanos, _ := decVint(data[i+j:]) + return int32(month), int32(days), nanos +} + +func decVint(data []byte) (int64, int) { + firstByte := data[0] + if firstByte&0x80 == 0 { + return decIntZigZag(uint64(firstByte)), 1 + } + numBytes := bits.LeadingZeros32(uint32(^firstByte)) - 24 + ret := uint64(firstByte & (0xff >> uint(numBytes))) + for i := 0; i < numBytes; i++ { + ret <<= 8 + ret |= uint64(data[i+1] & 0xff) + } + return decIntZigZag(ret), numBytes + 1 +} + +func decIntZigZag(n uint64) int64 { + return int64((n >> 1) ^ -(n & 1)) +} + +func encIntZigZag(n int64) uint64 { + return uint64((n >> 63) ^ (n << 1)) +} + +func encVints(months int32, seconds int32, nanos int64) []byte { + buf := append(encVint(int64(months)), encVint(int64(seconds))...) + return append(buf, encVint(nanos)...) 
+} + +func encVint(v int64) []byte { + vEnc := encIntZigZag(v) + lead0 := bits.LeadingZeros64(vEnc) + numBytes := (639 - lead0*9) >> 6 + + // It can be 1 or 0 is v ==0 + if numBytes <= 1 { + return []byte{byte(vEnc)} + } + extraBytes := numBytes - 1 + var buf = make([]byte, numBytes) + for i := extraBytes; i >= 0; i-- { + buf[i] = byte(vEnc) + vEnc >>= 8 + } + buf[0] |= byte(^(0xff >> uint(extraBytes))) + return buf +} + func writeCollectionSize(info CollectionType, n int, buf *bytes.Buffer) error { if info.proto > protoVersion2 { if n > math.MaxInt32 { @@ -1285,11 +1498,17 @@ return nil, marshalErrorf("can not marshal %T into %s", value, info) } -func readCollectionSize(info CollectionType, data []byte) (size, read int) { +func readCollectionSize(info CollectionType, data []byte) (size, read int, err error) { if info.proto > protoVersion2 { + if len(data) < 4 { + return 0, 0, unmarshalErrorf("unmarshal list: unexpected eof") + } size = int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 | int(data[3]) read = 4 } else { + if len(data) < 2 { + return 0, 0, unmarshalErrorf("unmarshal list: unexpected eof") + } size = int(data[0])<<8 | int(data[1]) read = 2 } @@ -1322,10 +1541,10 @@ rv.Set(reflect.Zero(t)) return nil } - if len(data) < 2 { - return unmarshalErrorf("unmarshal list: unexpected eof") + n, p, err := readCollectionSize(listInfo, data) + if err != nil { + return err } - n, p := readCollectionSize(listInfo, data) data = data[p:] if k == reflect.Array { if rv.Len() != n { @@ -1335,10 +1554,10 @@ rv.Set(reflect.MakeSlice(t, n, n)) } for i := 0; i < n; i++ { - if len(data) < 2 { - return unmarshalErrorf("unmarshal list: unexpected eof") + m, p, err := readCollectionSize(listInfo, data) + if err != nil { + return err } - m, p := readCollectionSize(listInfo, data) data = data[p:] if err := Unmarshal(listInfo.Elem, data[:m], rv.Index(i).Addr().Interface()); err != nil { return err @@ -1363,15 +1582,16 @@ } rv := reflect.ValueOf(value) - if rv.IsNil() { - return nil, nil - } t := rv.Type() if t.Kind() != reflect.Map { return nil, marshalErrorf("can not marshal %T into %s", value, info) } + if rv.IsNil() { + return nil, nil + } + buf := &bytes.Buffer{} n := rv.Len() @@ -1422,16 +1642,16 @@ return nil } rv.Set(reflect.MakeMap(t)) - if len(data) < 2 { - return unmarshalErrorf("unmarshal map: unexpected eof") + n, p, err := readCollectionSize(mapInfo, data) + if err != nil { + return err } - n, p := readCollectionSize(mapInfo, data) data = data[p:] for i := 0; i < n; i++ { - if len(data) < 2 { - return unmarshalErrorf("unmarshal list: unexpected eof") + m, p, err := readCollectionSize(mapInfo, data) + if err != nil { + return err } - m, p := readCollectionSize(mapInfo, data) data = data[p:] key := reflect.New(t.Key()) if err := Unmarshal(mapInfo.Key, data[:m], key.Interface()); err != nil { @@ -1439,7 +1659,10 @@ } data = data[m:] - m, p = readCollectionSize(mapInfo, data) + m, p, err = readCollectionSize(mapInfo, data) + if err != nil { + return err + } data = data[p:] val := reflect.New(t.Elem()) if err := Unmarshal(mapInfo.Elem, data[:m], val.Interface()); err != nil { @@ -2053,6 +2276,17 @@ Elems []TypeInfo } +func (t TupleTypeInfo) String() string { + var buf bytes.Buffer + buf.WriteString(fmt.Sprintf("%s(", t.typ)) + for _, elem := range t.Elems { + buf.WriteString(fmt.Sprintf("%s, ", elem)) + } + buf.Truncate(buf.Len() - 2) + buf.WriteByte(')') + return buf.String() +} + func (t TupleTypeInfo) New() interface{} { return reflect.New(goType(t)).Interface() } @@ -2119,6 +2353,7 @@ 
TypeTime Type = 0x0012 TypeSmallInt Type = 0x0013 TypeTinyInt Type = 0x0014 + TypeDuration Type = 0x0015 TypeList Type = 0x0020 TypeMap Type = 0x0021 TypeSet Type = 0x0022 @@ -2163,6 +2398,8 @@ return "inet" case TypeDate: return "date" + case TypeDuration: + return "duration" case TypeTime: return "time" case TypeSmallInt: diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/marshal_test.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/marshal_test.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/marshal_test.go 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/marshal_test.go 2019-11-02 13:15:23.000000000 +0000 @@ -1,4 +1,4 @@ -// +build all unit +//+build all unit package gocql @@ -16,6 +16,11 @@ ) type AliasInt int +type AliasUint uint +type AliasUint8 uint8 +type AliasUint16 uint16 +type AliasUint32 uint32 +type AliasUint64 uint64 var marshalTests = []struct { Info TypeInfo @@ -315,6 +320,20 @@ nil, }, { + NativeType{proto: 4, typ: TypeTime}, + []byte("\x00\x00\x01\x40\x77\x16\xe1\xb8"), + time.Duration(int64(1376387523000)), + nil, + nil, + }, + { + NativeType{proto: 4, typ: TypeTime}, + []byte("\x00\x00\x01\x40\x77\x16\xe1\xb8"), + int64(1376387523000), + nil, + nil, + }, + { NativeType{proto: 2, typ: TypeTimestamp}, []byte("\x00\x00\x01\x40\x77\x16\xe1\xb8"), time.Date(2013, time.August, 13, 9, 52, 3, 0, time.UTC), @@ -329,6 +348,27 @@ nil, }, { + NativeType{proto: 5, typ: TypeDuration}, + []byte("\x89\xa2\xc3\xc2\x9a\xe0F\x91\x06"), + Duration{Months: 1233, Days: 123213, Nanoseconds: 2312323}, + nil, + nil, + }, + { + NativeType{proto: 5, typ: TypeDuration}, + []byte("\x89\xa1\xc3\xc2\x99\xe0F\x91\x05"), + Duration{Months: -1233, Days: -123213, Nanoseconds: -2312323}, + nil, + nil, + }, + { + NativeType{proto: 5, typ: TypeDuration}, + []byte("\x02\x04\x80\xe6"), + Duration{Months: 1, Days: 2, Nanoseconds: 115}, + nil, + nil, + }, + { CollectionType{ NativeType: NativeType{proto: 2, typ: TypeList}, Elem: NativeType{proto: 2, typ: TypeInt}, @@ -769,12 +809,68 @@ }, { NativeType{proto: 2, typ: TypeSmallInt}, + []byte("\x00\xff"), + uint8(255), + nil, + nil, + }, + { + NativeType{proto: 2, typ: TypeSmallInt}, []byte("\xff\xff"), uint16(65535), nil, nil, }, { + NativeType{proto: 2, typ: TypeSmallInt}, + []byte("\xff\xff"), + uint32(65535), + nil, + nil, + }, + { + NativeType{proto: 2, typ: TypeSmallInt}, + []byte("\xff\xff"), + uint64(65535), + nil, + nil, + }, + { + NativeType{proto: 2, typ: TypeSmallInt}, + []byte("\x00\xff"), + AliasUint8(255), + nil, + nil, + }, + { + NativeType{proto: 2, typ: TypeSmallInt}, + []byte("\xff\xff"), + AliasUint16(65535), + nil, + nil, + }, + { + NativeType{proto: 2, typ: TypeSmallInt}, + []byte("\xff\xff"), + AliasUint32(65535), + nil, + nil, + }, + { + NativeType{proto: 2, typ: TypeSmallInt}, + []byte("\xff\xff"), + AliasUint64(65535), + nil, + nil, + }, + { + NativeType{proto: 2, typ: TypeSmallInt}, + []byte("\xff\xff"), + AliasUint(65535), + nil, + nil, + }, + { NativeType{proto: 2, typ: TypeTinyInt}, []byte("\x7f"), 127, // math.MaxInt8 @@ -838,6 +934,62 @@ nil, }, { + NativeType{proto: 2, typ: TypeTinyInt}, + []byte("\xff"), + AliasUint8(255), + nil, + nil, + }, + { + NativeType{proto: 2, typ: TypeTinyInt}, + []byte("\xff"), + AliasUint64(255), + nil, + nil, + }, + { + NativeType{proto: 2, typ: TypeTinyInt}, + []byte("\xff"), + AliasUint32(255), + nil, + nil, + }, + { + NativeType{proto: 2, typ: TypeTinyInt}, + []byte("\xff"), + AliasUint16(255), + nil, + nil, + }, + { + 
NativeType{proto: 2, typ: TypeTinyInt}, + []byte("\xff"), + AliasUint(255), + nil, + nil, + }, + { + NativeType{proto: 2, typ: TypeBigInt}, + []byte("\x00\x00\x00\x00\x00\x00\x00\xff"), + uint8(math.MaxUint8), + nil, + nil, + }, + { + NativeType{proto: 2, typ: TypeBigInt}, + []byte("\x00\x00\x00\x00\x00\x00\xff\xff"), + uint64(math.MaxUint16), + nil, + nil, + }, + { + NativeType{proto: 2, typ: TypeBigInt}, + []byte("\x00\x00\x00\x00\xff\xff\xff\xff"), + uint64(math.MaxUint32), + nil, + nil, + }, + { NativeType{proto: 2, typ: TypeBigInt}, []byte("\xff\xff\xff\xff\xff\xff\xff\xff"), uint64(math.MaxUint64), @@ -852,6 +1004,13 @@ nil, }, { + NativeType{proto: 2, typ: TypeInt}, + []byte("\xff\xff\xff\xff"), + uint64(math.MaxUint32), + nil, + nil, + }, + { NativeType{proto: 2, typ: TypeBlob}, []byte(nil), ([]byte)(nil), @@ -877,6 +1036,182 @@ }, } +var unmarshalTests = []struct { + Info TypeInfo + Data []byte + Value interface{} + UnmarshalError error +}{ + { + NativeType{proto: 2, typ: TypeSmallInt}, + []byte("\xff\xff"), + uint8(0), + UnmarshalError("unmarshal int: value -1 out of range for uint8"), + }, + { + NativeType{proto: 2, typ: TypeSmallInt}, + []byte("\x01\x00"), + uint8(0), + UnmarshalError("unmarshal int: value 256 out of range for uint8"), + }, + { + NativeType{proto: 2, typ: TypeInt}, + []byte("\xff\xff\xff\xff"), + uint8(0), + UnmarshalError("unmarshal int: value -1 out of range for uint8"), + }, + { + NativeType{proto: 2, typ: TypeInt}, + []byte("\x00\x00\x01\x00"), + uint8(0), + UnmarshalError("unmarshal int: value 256 out of range for uint8"), + }, + { + NativeType{proto: 2, typ: TypeInt}, + []byte("\xff\xff\xff\xff"), + uint16(0), + UnmarshalError("unmarshal int: value -1 out of range for uint16"), + }, + { + NativeType{proto: 2, typ: TypeInt}, + []byte("\x00\x01\x00\x00"), + uint16(0), + UnmarshalError("unmarshal int: value 65536 out of range for uint16"), + }, + { + NativeType{proto: 2, typ: TypeBigInt}, + []byte("\xff\xff\xff\xff\xff\xff\xff\xff"), + uint8(0), + UnmarshalError("unmarshal int: value -1 out of range for uint8"), + }, + { + NativeType{proto: 2, typ: TypeBigInt}, + []byte("\x00\x00\x00\x00\x00\x00\x01\x00"), + uint8(0), + UnmarshalError("unmarshal int: value 256 out of range for uint8"), + }, + { + NativeType{proto: 2, typ: TypeBigInt}, + []byte("\xff\xff\xff\xff\xff\xff\xff\xff"), + uint8(0), + UnmarshalError("unmarshal int: value -1 out of range for uint8"), + }, + { + NativeType{proto: 2, typ: TypeBigInt}, + []byte("\x00\x00\x00\x00\x00\x00\x01\x00"), + uint8(0), + UnmarshalError("unmarshal int: value 256 out of range for uint8"), + }, + { + NativeType{proto: 2, typ: TypeBigInt}, + []byte("\xff\xff\xff\xff\xff\xff\xff\xff"), + uint16(0), + UnmarshalError("unmarshal int: value -1 out of range for uint16"), + }, + { + NativeType{proto: 2, typ: TypeBigInt}, + []byte("\x00\x00\x00\x00\x00\x01\x00\x00"), + uint16(0), + UnmarshalError("unmarshal int: value 65536 out of range for uint16"), + }, + { + NativeType{proto: 2, typ: TypeBigInt}, + []byte("\xff\xff\xff\xff\xff\xff\xff\xff"), + uint32(0), + UnmarshalError("unmarshal int: value -1 out of range for uint32"), + }, + { + NativeType{proto: 2, typ: TypeBigInt}, + []byte("\x00\x00\x00\x01\x00\x00\x00\x00"), + uint32(0), + UnmarshalError("unmarshal int: value 4294967296 out of range for uint32"), + }, + { + NativeType{proto: 2, typ: TypeSmallInt}, + []byte("\xff\xff"), + AliasUint8(0), + UnmarshalError("unmarshal int: value -1 out of range for gocql.AliasUint8"), + }, + { + NativeType{proto: 2, typ: TypeSmallInt}, 
+ []byte("\x01\x00"), + AliasUint8(0), + UnmarshalError("unmarshal int: value 256 out of range for gocql.AliasUint8"), + }, + { + NativeType{proto: 2, typ: TypeInt}, + []byte("\xff\xff\xff\xff"), + AliasUint8(0), + UnmarshalError("unmarshal int: value -1 out of range for gocql.AliasUint8"), + }, + { + NativeType{proto: 2, typ: TypeInt}, + []byte("\x00\x00\x01\x00"), + AliasUint8(0), + UnmarshalError("unmarshal int: value 256 out of range for gocql.AliasUint8"), + }, + { + NativeType{proto: 2, typ: TypeInt}, + []byte("\xff\xff\xff\xff"), + AliasUint16(0), + UnmarshalError("unmarshal int: value -1 out of range for gocql.AliasUint16"), + }, + { + NativeType{proto: 2, typ: TypeInt}, + []byte("\x00\x01\x00\x00"), + AliasUint16(0), + UnmarshalError("unmarshal int: value 65536 out of range for gocql.AliasUint16"), + }, + { + NativeType{proto: 2, typ: TypeBigInt}, + []byte("\xff\xff\xff\xff\xff\xff\xff\xff"), + AliasUint8(0), + UnmarshalError("unmarshal int: value -1 out of range for gocql.AliasUint8"), + }, + { + NativeType{proto: 2, typ: TypeBigInt}, + []byte("\x00\x00\x00\x00\x00\x00\x01\x00"), + AliasUint8(0), + UnmarshalError("unmarshal int: value 256 out of range for gocql.AliasUint8"), + }, + { + NativeType{proto: 2, typ: TypeBigInt}, + []byte("\xff\xff\xff\xff\xff\xff\xff\xff"), + AliasUint8(0), + UnmarshalError("unmarshal int: value -1 out of range for gocql.AliasUint8"), + }, + { + NativeType{proto: 2, typ: TypeBigInt}, + []byte("\x00\x00\x00\x00\x00\x00\x01\x00"), + AliasUint8(0), + UnmarshalError("unmarshal int: value 256 out of range for gocql.AliasUint8"), + }, + { + NativeType{proto: 2, typ: TypeBigInt}, + []byte("\xff\xff\xff\xff\xff\xff\xff\xff"), + AliasUint16(0), + UnmarshalError("unmarshal int: value -1 out of range for gocql.AliasUint16"), + }, + { + NativeType{proto: 2, typ: TypeBigInt}, + []byte("\x00\x00\x00\x00\x00\x01\x00\x00"), + AliasUint16(0), + UnmarshalError("unmarshal int: value 65536 out of range for gocql.AliasUint16"), + }, + { + NativeType{proto: 2, typ: TypeBigInt}, + []byte("\xff\xff\xff\xff\xff\xff\xff\xff"), + AliasUint32(0), + UnmarshalError("unmarshal int: value -1 out of range for gocql.AliasUint32"), + }, + { + NativeType{proto: 2, typ: TypeBigInt}, + []byte("\x00\x00\x00\x01\x00\x00\x00\x00"), + AliasUint32(0), + UnmarshalError("unmarshal int: value 4294967296 out of range for gocql.AliasUint32"), + }, +} + func decimalize(s string) *inf.Dec { i, _ := new(inf.Dec).SetString(s) return i @@ -920,7 +1255,24 @@ } } else { if err := Unmarshal(test.Info, test.Data, test.Value); err != test.UnmarshalError { - t.Errorf("unmarshalTest[%d] (%v=>%t): %#v returned error %#v, want %#v.", i, test.Info, test.Value, test.Value, err, test.UnmarshalError) + t.Errorf("unmarshalTest[%d] (%v=>%T): %#v returned error %#v, want %#v.", i, test.Info, test.Value, test.Value, err, test.UnmarshalError) + } + } + } + for i, test := range unmarshalTests { + v := reflect.New(reflect.TypeOf(test.Value)) + if test.UnmarshalError == nil { + err := Unmarshal(test.Info, test.Data, v.Interface()) + if err != nil { + t.Errorf("unmarshalTest[%d] (%v=>%T): %v", i, test.Info, test.Value, err) + continue + } + if !reflect.DeepEqual(v.Elem().Interface(), test.Value) { + t.Errorf("unmarshalTest[%d] (%v=>%T): expected %#v, got %#v.", i, test.Info, test.Value, test.Value, v.Elem().Interface()) + } + } else { + if err := Unmarshal(test.Info, test.Data, v.Interface()); err != test.UnmarshalError { + t.Errorf("unmarshalTest[%d] (%v=>%T): %#v returned error %#v, want %#v.", i, test.Info, test.Value, 
test.Value, err, test.UnmarshalError) } } } @@ -1196,6 +1548,46 @@ } } +func TestMarshalTime(t *testing.T) { + durationS := "1h10m10s" + duration, _ := time.ParseDuration(durationS) + expectedData := encBigInt(duration.Nanoseconds()) + var marshalTimeTests = []struct { + Info TypeInfo + Data []byte + Value interface{} + }{ + { + NativeType{proto: 4, typ: TypeTime}, + expectedData, + duration.Nanoseconds(), + }, + { + NativeType{proto: 4, typ: TypeTime}, + expectedData, + duration, + }, + { + NativeType{proto: 4, typ: TypeTime}, + expectedData, + &duration, + }, + } + + for i, test := range marshalTimeTests { + t.Log(i, test) + data, err := Marshal(test.Info, test.Value) + if err != nil { + t.Errorf("marshalTest[%d]: %v", i, err) + continue + } + if !bytes.Equal(data, test.Data) { + t.Errorf("marshalTest[%d]: expected %x (%v), got %x (%v) for time %s", i, + test.Data, decInt(test.Data), data, decInt(data), test.Value) + } + } +} + func TestMarshalTimestamp(t *testing.T) { var marshalTimestampTests = []struct { Info TypeInfo @@ -1355,10 +1747,18 @@ t.Errorf("marshalTest: expected %v, got %v", expectedDate, formattedDate) return } + var stringDate string + if err2 := unmarshalDate(NativeType{proto: 2, typ: TypeDate}, data, &stringDate); err2 != nil { + t.Fatal(err2) + } + if expectedDate != stringDate { + t.Errorf("marshalTest: expected %v, got %v", expectedDate, formattedDate) + return + } } func TestMarshalDate(t *testing.T) { - now := time.Now() + now := time.Now().UTC() timestamp := now.UnixNano() / int64(time.Millisecond) expectedData := encInt(int32(timestamp/86400000 + int64(1<<31))) var marshalDateTests = []struct { @@ -1401,3 +1801,146 @@ } } } + +func BenchmarkUnmarshalVarchar(b *testing.B) { + b.ReportAllocs() + src := make([]byte, 1024) + dst := make([]byte, len(src)) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := unmarshalVarchar(NativeType{}, src, &dst); err != nil { + b.Fatal(err) + } + } +} + +func TestMarshalDuration(t *testing.T) { + durationS := "1h10m10s" + duration, _ := time.ParseDuration(durationS) + expectedData := append([]byte{0, 0}, encVint(duration.Nanoseconds())...) 
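
For reference, the duration wire format introduced in marshal.go is three zig-zag vints in sequence: months, days, then nanoseconds. encVint(0) encodes to a single 0x00 byte, so the two literal zero bytes prepended above stand for the zero months and zero days of a plain time.Duration. The same expected bytes could also be built with the encVints helper; an illustrative check, not part of the upstream test:

    // Equivalent construction: a time.Duration carries no month or day
    // component, so only the nanosecond vint varies.
    altExpected := encVints(0, 0, duration.Nanoseconds())
    if !bytes.Equal(altExpected, expectedData) {
        t.Fatalf("expectedData mismatch: %x vs %x", expectedData, altExpected)
    }
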
+ var marshalDurationTests = []struct { + Info TypeInfo + Data []byte + Value interface{} + }{ + { + NativeType{proto: 5, typ: TypeDuration}, + expectedData, + duration.Nanoseconds(), + }, + { + NativeType{proto: 5, typ: TypeDuration}, + expectedData, + duration, + }, + { + NativeType{proto: 5, typ: TypeDuration}, + expectedData, + durationS, + }, + { + NativeType{proto: 5, typ: TypeDuration}, + expectedData, + &duration, + }, + } + + for i, test := range marshalDurationTests { + t.Log(i, test) + data, err := Marshal(test.Info, test.Value) + if err != nil { + t.Errorf("marshalTest[%d]: %v", i, err) + continue + } + if !bytes.Equal(data, test.Data) { + t.Errorf("marshalTest[%d]: expected %x (%v), got %x (%v) for time %s", i, + test.Data, decInt(test.Data), data, decInt(data), test.Value) + } + } +} + +func TestReadCollectionSize(t *testing.T) { + listV2 := CollectionType{ + NativeType: NativeType{proto: 2, typ: TypeList}, + Elem: NativeType{proto: 2, typ: TypeVarchar}, + } + listV3 := CollectionType{ + NativeType: NativeType{proto: 3, typ: TypeList}, + Elem: NativeType{proto: 3, typ: TypeVarchar}, + } + + tests := []struct { + name string + info CollectionType + data []byte + isError bool + expectedSize int + }{ + { + name: "short read 0 proto 2", + info: listV2, + data: []byte{}, + isError: true, + }, + { + name: "short read 1 proto 2", + info: listV2, + data: []byte{0x01}, + isError: true, + }, + { + name: "good read proto 2", + info: listV2, + data: []byte{0x01, 0x38}, + expectedSize: 0x0138, + }, + { + name: "short read 0 proto 3", + info: listV3, + data: []byte{}, + isError: true, + }, + { + name: "short read 1 proto 3", + info: listV3, + data: []byte{0x01}, + isError: true, + }, + { + name: "short read 2 proto 3", + info: listV3, + data: []byte{0x01, 0x38}, + isError: true, + }, + { + name: "short read 3 proto 3", + info: listV3, + data: []byte{0x01, 0x38, 0x42}, + isError: true, + }, + { + name: "good read proto 3", + info: listV3, + data: []byte{0x01, 0x38, 0x42, 0x22}, + expectedSize: 0x01384222, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + size, _, err := readCollectionSize(test.info, test.data) + if test.isError { + if err == nil { + t.Fatal("Expected error, but it was nil") + } + } else { + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if size != test.expectedSize { + t.Fatalf("Expected size of %d, but got %d", test.expectedSize, size) + } + } + }) + } +} diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/metadata.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/metadata.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/metadata.go 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/metadata.go 2019-11-02 13:15:23.000000000 +0000 @@ -20,6 +20,9 @@ StrategyClass string StrategyOptions map[string]interface{} Tables map[string]*TableMetadata + Functions map[string]*FunctionMetadata + Aggregates map[string]*AggregateMetadata + Views map[string]*ViewMetadata } // schema metadata for a table (a.k.a. 
column family) @@ -52,6 +55,41 @@ Index ColumnIndexMetadata } +// FunctionMetadata holds metadata for function constructs +type FunctionMetadata struct { + Keyspace string + Name string + ArgumentTypes []TypeInfo + ArgumentNames []string + Body string + CalledOnNullInput bool + Language string + ReturnType TypeInfo +} + +// AggregateMetadata holds metadata for aggregate constructs +type AggregateMetadata struct { + Keyspace string + Name string + ArgumentTypes []TypeInfo + FinalFunc FunctionMetadata + InitCond string + ReturnType TypeInfo + StateFunc FunctionMetadata + StateType TypeInfo + + stateFunc string + finalFunc string +} + +// ViewMetadata holds the metadata for views. +type ViewMetadata struct { + Keyspace string + Name string + FieldNames []string + FieldTypes []TypeInfo +} + // the ordering of the column with regard to its comparator type ColumnOrder bool @@ -196,9 +234,21 @@ if err != nil { return err } + functions, err := getFunctionsMetadata(s.session, keyspaceName) + if err != nil { + return err + } + aggregates, err := getAggregatesMetadata(s.session, keyspaceName) + if err != nil { + return err + } + views, err := getViewsMetadata(s.session, keyspaceName) + if err != nil { + return err + } // organize the schema data - compileMetadata(s.session.cfg.ProtoVersion, keyspace, tables, columns) + compileMetadata(s.session.cfg.ProtoVersion, keyspace, tables, columns, functions, aggregates, views) // update the cache s.cache[keyspaceName] = keyspace @@ -216,6 +266,9 @@ keyspace *KeyspaceMetadata, tables []TableMetadata, columns []ColumnMetadata, + functions []FunctionMetadata, + aggregates []AggregateMetadata, + views []ViewMetadata, ) { keyspace.Tables = make(map[string]*TableMetadata) for i := range tables { @@ -223,28 +276,51 @@ keyspace.Tables[tables[i].Name] = &tables[i] } + keyspace.Functions = make(map[string]*FunctionMetadata, len(functions)) + for i := range functions { + keyspace.Functions[functions[i].Name] = &functions[i] + } + keyspace.Aggregates = make(map[string]*AggregateMetadata, len(aggregates)) + for _, aggregate := range aggregates { + aggregate.FinalFunc = *keyspace.Functions[aggregate.finalFunc] + aggregate.StateFunc = *keyspace.Functions[aggregate.stateFunc] + keyspace.Aggregates[aggregate.Name] = &aggregate + } + keyspace.Views = make(map[string]*ViewMetadata, len(views)) + for i := range views { + keyspace.Views[views[i].Name] = &views[i] + } // add columns from the schema data for i := range columns { + col := &columns[i] // decode the validator for TypeInfo and order - if columns[i].ClusteringOrder != "" { // Cassandra 3.x+ - columns[i].Type = NativeType{typ: getCassandraType(columns[i].Validator)} - columns[i].Order = ASC - if columns[i].ClusteringOrder == "desc" { - columns[i].Order = DESC + if col.ClusteringOrder != "" { // Cassandra 3.x+ + col.Type = getCassandraType(col.Validator) + col.Order = ASC + if col.ClusteringOrder == "desc" { + col.Order = DESC } } else { - validatorParsed := parseType(columns[i].Validator) - columns[i].Type = validatorParsed.types[0] - columns[i].Order = ASC + validatorParsed := parseType(col.Validator) + col.Type = validatorParsed.types[0] + col.Order = ASC if validatorParsed.reversed[0] { - columns[i].Order = DESC + col.Order = DESC } } - table := keyspace.Tables[columns[i].Table] - table.Columns[columns[i].Name] = &columns[i] - table.OrderedColumns = append(table.OrderedColumns, columns[i].Name) + table, ok := keyspace.Tables[col.Table] + if !ok { + // if the schema is being updated we will race between seeing + // the 
metadata be complete. Potentially we should check for + // schema versions before and after reading the metadata and + // if they dont match try again. + continue + } + + table.Columns[col.Name] = col + table.OrderedColumns = append(table.OrderedColumns, col.Name) } if protoVersion == protoVersion1 { @@ -426,8 +502,9 @@ } keyspace.StrategyClass = replication["class"] + delete(replication, "class") - keyspace.StrategyOptions = make(map[string]interface{}) + keyspace.StrategyOptions = make(map[string]interface{}, len(replication)) for k, v := range replication { keyspace.StrategyOptions[k] = v } @@ -783,6 +860,171 @@ return columns, nil } +func getTypeInfo(t string) TypeInfo { + if strings.HasPrefix(t, apacheCassandraTypePrefix) { + t = apacheToCassandraType(t) + } + return getCassandraType(t) +} + +func getViewsMetadata(session *Session, keyspaceName string) ([]ViewMetadata, error) { + if session.cfg.ProtoVersion == protoVersion1 { + return nil, nil + } + var tableName string + if session.useSystemSchema { + tableName = "system_schema.types" + } else { + tableName = "system.schema_usertypes" + } + stmt := fmt.Sprintf(` + SELECT + type_name, + field_names, + field_types + FROM %s + WHERE keyspace_name = ?`, tableName) + + var views []ViewMetadata + + rows := session.control.query(stmt, keyspaceName).Scanner() + for rows.Next() { + view := ViewMetadata{Keyspace: keyspaceName} + var argumentTypes []string + err := rows.Scan(&view.Name, + &view.FieldNames, + &argumentTypes, + ) + if err != nil { + return nil, err + } + view.FieldTypes = make([]TypeInfo, len(argumentTypes)) + for i, argumentType := range argumentTypes { + view.FieldTypes[i] = getTypeInfo(argumentType) + } + views = append(views, view) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return views, nil +} + +func getFunctionsMetadata(session *Session, keyspaceName string) ([]FunctionMetadata, error) { + if session.cfg.ProtoVersion == protoVersion1 || !session.hasAggregatesAndFunctions { + return nil, nil + } + var tableName string + if session.useSystemSchema { + tableName = "system_schema.functions" + } else { + tableName = "system.schema_functions" + } + stmt := fmt.Sprintf(` + SELECT + function_name, + argument_types, + argument_names, + body, + called_on_null_input, + language, + return_type + FROM %s + WHERE keyspace_name = ?`, tableName) + + var functions []FunctionMetadata + + rows := session.control.query(stmt, keyspaceName).Scanner() + for rows.Next() { + function := FunctionMetadata{Keyspace: keyspaceName} + var argumentTypes []string + var returnType string + err := rows.Scan(&function.Name, + &argumentTypes, + &function.ArgumentNames, + &function.Body, + &function.CalledOnNullInput, + &function.Language, + &returnType, + ) + if err != nil { + return nil, err + } + function.ReturnType = getTypeInfo(returnType) + function.ArgumentTypes = make([]TypeInfo, len(argumentTypes)) + for i, argumentType := range argumentTypes { + function.ArgumentTypes[i] = getTypeInfo(argumentType) + } + functions = append(functions, function) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return functions, nil +} + +func getAggregatesMetadata(session *Session, keyspaceName string) ([]AggregateMetadata, error) { + if session.cfg.ProtoVersion == protoVersion1 || !session.hasAggregatesAndFunctions { + return nil, nil + } + var tableName string + if session.useSystemSchema { + tableName = "system_schema.aggregates" + } else { + tableName = "system.schema_aggregates" + } + + stmt := fmt.Sprintf(` + SELECT + 
aggregate_name, + argument_types, + final_func, + initcond, + return_type, + state_func, + state_type + FROM %s + WHERE keyspace_name = ?`, tableName) + + var aggregates []AggregateMetadata + + rows := session.control.query(stmt, keyspaceName).Scanner() + for rows.Next() { + aggregate := AggregateMetadata{Keyspace: keyspaceName} + var argumentTypes []string + var returnType string + var stateType string + err := rows.Scan(&aggregate.Name, + &argumentTypes, + &aggregate.finalFunc, + &aggregate.InitCond, + &returnType, + &aggregate.stateFunc, + &stateType, + ) + if err != nil { + return nil, err + } + aggregate.ReturnType = getTypeInfo(returnType) + aggregate.StateType = getTypeInfo(stateType) + aggregate.ArgumentTypes = make([]TypeInfo, len(argumentTypes)) + for i, argumentType := range argumentTypes { + aggregate.ArgumentTypes[i] = getTypeInfo(argumentType) + } + aggregates = append(aggregates, aggregate) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return aggregates, nil +} + // type definition parser state type typeParser struct { input string diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/metadata_test.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/metadata_test.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/metadata_test.go 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/metadata_test.go 2019-11-02 13:15:23.000000000 +0000 @@ -94,7 +94,7 @@ {Keyspace: "V1Keyspace", Table: "peers", Kind: ColumnRegular, Name: "schema_version", ComponentIndex: 0, Validator: "org.apache.cassandra.db.marshal.UUIDType"}, {Keyspace: "V1Keyspace", Table: "peers", Kind: ColumnRegular, Name: "tokens", ComponentIndex: 0, Validator: "org.apache.cassandra.db.marshal.SetType(org.apache.cassandra.db.marshal.UTF8Type)"}, } - compileMetadata(1, keyspace, tables, columns) + compileMetadata(1, keyspace, tables, columns, nil, nil, nil) assertKeyspaceMetadata( t, keyspace, @@ -375,7 +375,7 @@ Validator: "org.apache.cassandra.db.marshal.UTF8Type", }, } - compileMetadata(2, keyspace, tables, columns) + compileMetadata(2, keyspace, tables, columns, nil, nil, nil) assertKeyspaceMetadata( t, keyspace, diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/policies.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/policies.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/policies.go 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/policies.go 2019-11-02 13:15:23.000000000 +0000 @@ -5,6 +5,10 @@ package gocql import ( + "context" + crand "crypto/rand" + "encoding/binary" + "errors" "fmt" "math" "math/rand" @@ -117,7 +121,7 @@ return false } - newL = newL[:size-1 : size-1] + newL = newL[: size-1 : size-1] c.list.Store(&newL) c.mu.Unlock() @@ -128,9 +132,24 @@ // exposes the correct functions for the retry policy logic to evaluate correctly. type RetryableQuery interface { Attempts() int + SetConsistency(c Consistency) GetConsistency() Consistency + Context() context.Context } +type RetryType uint16 + +const ( + Retry RetryType = 0x00 // retry on same connection + RetryNextHost RetryType = 0x01 // retry on another connection + Ignore RetryType = 0x02 // ignore error and return result + Rethrow RetryType = 0x03 // raise error and stop retrying +) + +// ErrUnknownRetryType is returned if the retry policy returns a retry type +// unknown to the query executor. 
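
The sentinel declared just below surfaces when a policy returns something other than the four RetryType constants above, so custom policies written against the old single-method RetryPolicy interface need to grow a GetRetryType method when upgrading. A minimal application-side policy might look like this sketch (illustrative only, not part of gocql):

    package myapp

    import "github.com/gocql/gocql"

    // blindRetryPolicy allows up to three attempts and always moves the retry
    // to another live host, whatever the error was.
    type blindRetryPolicy struct{}

    func (blindRetryPolicy) Attempt(q gocql.RetryableQuery) bool { return q.Attempts() < 3 }
    func (blindRetryPolicy) GetRetryType(error) gocql.RetryType  { return gocql.RetryNextHost }
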
+var ErrUnknownRetryType = errors.New("unknown retry type returned by retry policy") + // RetryPolicy interface is used by gocql to determine if a query can be attempted // again after a retryable error has been received. The interface allows gocql // users to implement their own logic to determine if a query can be attempted @@ -140,6 +159,7 @@ // interface. type RetryPolicy interface { Attempt(RetryableQuery) bool + GetRetryType(error) RetryType } // SimpleRetryPolicy has simple logic for attempting a query a fixed number of times. @@ -162,6 +182,10 @@ return q.Attempts() <= s.NumRetries } +func (s *SimpleRetryPolicy) GetRetryType(err error) RetryType { + return RetryNextHost +} + // ExponentialBackoffRetryPolicy sleeps between attempts type ExponentialBackoffRetryPolicy struct { NumRetries int @@ -176,23 +200,92 @@ return true } -func (e *ExponentialBackoffRetryPolicy) napTime(attempts int) time.Duration { - if e.Min <= 0 { - e.Min = 100 * time.Millisecond +// used to calculate exponentially growing time +func getExponentialTime(min time.Duration, max time.Duration, attempts int) time.Duration { + if min <= 0 { + min = 100 * time.Millisecond } - if e.Max <= 0 { - e.Max = 10 * time.Second + if max <= 0 { + max = 10 * time.Second } - minFloat := float64(e.Min) + minFloat := float64(min) napDuration := minFloat * math.Pow(2, float64(attempts-1)) // add some jitter napDuration += rand.Float64()*minFloat - (minFloat / 2) - if napDuration > float64(e.Max) { - return time.Duration(e.Max) + if napDuration > float64(max) { + return time.Duration(max) } return time.Duration(napDuration) } +func (e *ExponentialBackoffRetryPolicy) GetRetryType(err error) RetryType { + return RetryNextHost +} + +// DowngradingConsistencyRetryPolicy: Next retry will be with the next consistency level +// provided in the slice +// +// On a read timeout: the operation is retried with the next provided consistency +// level. +// +// On a write timeout: if the operation is an :attr:`~.UNLOGGED_BATCH` +// and at least one replica acknowledged the write, the operation is +// retried with the next consistency level. Furthermore, for other +// write types, if at least one replica acknowledged the write, the +// timeout is ignored. +// +// On an unavailable exception: if at least one replica is alive, the +// operation is retried with the next provided consistency level. 
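
Staying on the application side (same gocql import assumed), the downgrading policy is wired in through the cluster configuration together with the consistency ladder to walk down; ClusterConfig.RetryPolicy is the driver's existing hook and is not part of this patch:

    func newSession() (*gocql.Session, error) {
        cluster := gocql.NewCluster("127.0.0.1")
        // On retryable errors, step down through these consistency levels.
        cluster.RetryPolicy = &gocql.DowngradingConsistencyRetryPolicy{
            ConsistencyLevelsToTry: []gocql.Consistency{gocql.LocalQuorum, gocql.LocalOne, gocql.One},
        }
        return cluster.CreateSession()
    }
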
+ +type DowngradingConsistencyRetryPolicy struct { + ConsistencyLevelsToTry []Consistency +} + +func (d *DowngradingConsistencyRetryPolicy) Attempt(q RetryableQuery) bool { + currentAttempt := q.Attempts() + + if currentAttempt > len(d.ConsistencyLevelsToTry) { + return false + } else if currentAttempt > 0 { + q.SetConsistency(d.ConsistencyLevelsToTry[currentAttempt-1]) + if gocqlDebug { + Logger.Printf("%T: set consistency to %q\n", + d, + d.ConsistencyLevelsToTry[currentAttempt-1]) + } + } + return true +} + +func (d *DowngradingConsistencyRetryPolicy) GetRetryType(err error) RetryType { + switch t := err.(type) { + case *RequestErrUnavailable: + if t.Alive > 0 { + return Retry + } + return Rethrow + case *RequestErrWriteTimeout: + if t.WriteType == "SIMPLE" || t.WriteType == "BATCH" || t.WriteType == "COUNTER" { + if t.Received > 0 { + return Ignore + } + return Rethrow + } + if t.WriteType == "UNLOGGED_BATCH" { + return Retry + } + return Rethrow + case *RequestErrReadTimeout: + return Retry + default: + return RetryNextHost + } +} + +func (e *ExponentialBackoffRetryPolicy) napTime(attempts int) time.Duration { + return getExponentialTime(e.Min, e.Max, attempts) +} + type HostStateNotifier interface { AddHost(host *HostInfo) RemoveHost(host *HostInfo) @@ -200,11 +293,19 @@ HostDown(host *HostInfo) } +type KeyspaceUpdateEvent struct { + Keyspace string + Change string +} + // HostSelectionPolicy is an interface for selecting // the most appropriate host to execute a given query. type HostSelectionPolicy interface { HostStateNotifier SetPartitioner + KeyspaceChanged(KeyspaceUpdateEvent) + Init(*Session) + IsLocal(host *HostInfo) bool //Pick returns an iteration function over selected hosts Pick(ExecutableQuery) NextHost } @@ -235,34 +336,24 @@ type roundRobinHostPolicy struct { hosts cowHostList - pos uint32 - mu sync.RWMutex } -func (r *roundRobinHostPolicy) SetPartitioner(partitioner string) { - // noop -} +func (r *roundRobinHostPolicy) IsLocal(*HostInfo) bool { return true } +func (r *roundRobinHostPolicy) KeyspaceChanged(KeyspaceUpdateEvent) {} +func (r *roundRobinHostPolicy) SetPartitioner(partitioner string) {} +func (r *roundRobinHostPolicy) Init(*Session) {} func (r *roundRobinHostPolicy) Pick(qry ExecutableQuery) NextHost { - // i is used to limit the number of attempts to find a host - // to the number of hosts known to this policy - var i int - return func() SelectedHost { - hosts := r.hosts.get() - if len(hosts) == 0 { - return nil - } + src := r.hosts.get() + hosts := make([]*HostInfo, len(src)) + copy(hosts, src) + + rand := rand.New(randSource()) + rand.Shuffle(len(hosts), func(i, j int) { + hosts[i], hosts[j] = hosts[j], hosts[i] + }) - // always increment pos to evenly distribute traffic in case of - // failures - pos := atomic.AddUint32(&r.pos, 1) - 1 - if i >= len(hosts) { - return nil - } - host := hosts[(pos)%uint32(len(hosts))] - i++ - return (*selectedHost)(host) - } + return roundRobbin(hosts) } func (r *roundRobinHostPolicy) AddHost(host *HostInfo) { @@ -281,71 +372,208 @@ r.RemoveHost(host) } +func ShuffleReplicas() func(*tokenAwareHostPolicy) { + return func(t *tokenAwareHostPolicy) { + t.shuffleReplicas = true + } +} + +// NonLocalReplicasFallback enables fallback to replicas that are not considered local. +// +// TokenAwareHostPolicy used with DCAwareHostPolicy fallback first selects replicas by partition key in local DC, then +// falls back to other nodes in the local DC. 
Enabling NonLocalReplicasFallback causes TokenAwareHostPolicy +// to first select replicas by partition key in local DC, then replicas by partition key in remote DCs and fall back +// to other nodes in local DC. +func NonLocalReplicasFallback() func(policy *tokenAwareHostPolicy) { + return func(t *tokenAwareHostPolicy) { + t.nonLocalReplicasFallback = true + } +} + // TokenAwareHostPolicy is a token aware host selection policy, where hosts are // selected based on the partition key, so queries are sent to the host which // owns the partition. Fallback is used when routing information is not available. -func TokenAwareHostPolicy(fallback HostSelectionPolicy) HostSelectionPolicy { - return &tokenAwareHostPolicy{fallback: fallback} +func TokenAwareHostPolicy(fallback HostSelectionPolicy, opts ...func(*tokenAwareHostPolicy)) HostSelectionPolicy { + p := &tokenAwareHostPolicy{fallback: fallback} + for _, opt := range opts { + opt(p) + } + return p +} + +// clusterMeta holds metadata about cluster topology. +// It is used inside atomic.Value and shallow copies are used when replacing it, +// so fields should not be modified in-place. Instead, to modify a field a copy of the field should be made +// and the pointer in clusterMeta updated to point to the new value. +type clusterMeta struct { + // replicas is map[keyspace]map[token]hosts + replicas map[string]tokenRingReplicas + tokenRing *tokenRing } type tokenAwareHostPolicy struct { + fallback HostSelectionPolicy + getKeyspaceMetadata func(keyspace string) (*KeyspaceMetadata, error) + getKeyspaceName func() string + + shuffleReplicas bool + nonLocalReplicasFallback bool + + // mu protects writes to hosts, partitioner, metadata. + // reads can be unlocked as long as they are not used for updating state later. + mu sync.Mutex hosts cowHostList - mu sync.RWMutex partitioner string - tokenRing *tokenRing - fallback HostSelectionPolicy + metadata atomic.Value // *clusterMeta +} + +func (t *tokenAwareHostPolicy) Init(s *Session) { + t.getKeyspaceMetadata = s.KeyspaceMetadata + t.getKeyspaceName = func() string { return s.cfg.Keyspace } +} + +func (t *tokenAwareHostPolicy) IsLocal(host *HostInfo) bool { + return t.fallback.IsLocal(host) +} + +func (t *tokenAwareHostPolicy) KeyspaceChanged(update KeyspaceUpdateEvent) { + t.mu.Lock() + defer t.mu.Unlock() + meta := t.getMetadataForUpdate() + t.updateReplicas(meta, update.Keyspace) + t.metadata.Store(meta) +} + +// updateReplicas updates replicas in clusterMeta. +// It must be called with t.mu mutex locked. +// meta must not be nil and it's replicas field will be updated. 
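
The per-keyspace replica maps maintained by updateReplicas below are what the token-aware Pick path consults when routing by partition key. On the application side the policy is normally composed with a datacenter-aware fallback plus the two options introduced above; a sketch, assuming the driver's usual PoolConfig.HostSelectionPolicy hook (not shown in this patch):

    cluster := gocql.NewCluster("127.0.0.1")
    cluster.PoolConfig.HostSelectionPolicy = gocql.TokenAwareHostPolicy(
        gocql.DCAwareRoundRobinPolicy("dc1"), // prefer replicas in the local DC
        gocql.ShuffleReplicas(),              // spread load across equally ranked replicas
        gocql.NonLocalReplicasFallback(),     // then consider replicas in remote DCs
    )
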
+func (t *tokenAwareHostPolicy) updateReplicas(meta *clusterMeta, keyspace string) { + newReplicas := make(map[string]tokenRingReplicas, len(meta.replicas)) + + ks, err := t.getKeyspaceMetadata(keyspace) + if err == nil { + strat := getStrategy(ks) + if strat != nil { + if meta != nil && meta.tokenRing != nil { + newReplicas[keyspace] = strat.replicaMap(meta.tokenRing) + } + } + } + + for ks, replicas := range meta.replicas { + if ks != keyspace { + newReplicas[ks] = replicas + } + } + + meta.replicas = newReplicas } func (t *tokenAwareHostPolicy) SetPartitioner(partitioner string) { + t.mu.Lock() + defer t.mu.Unlock() + if t.partitioner != partitioner { t.fallback.SetPartitioner(partitioner) t.partitioner = partitioner - - t.resetTokenRing() + meta := t.getMetadataForUpdate() + meta.resetTokenRing(t.partitioner, t.hosts.get()) + t.updateReplicas(meta, t.getKeyspaceName()) + t.metadata.Store(meta) } } func (t *tokenAwareHostPolicy) AddHost(host *HostInfo) { - t.hosts.add(host) + t.mu.Lock() + if t.hosts.add(host) { + meta := t.getMetadataForUpdate() + meta.resetTokenRing(t.partitioner, t.hosts.get()) + t.updateReplicas(meta, t.getKeyspaceName()) + t.metadata.Store(meta) + } + t.mu.Unlock() + t.fallback.AddHost(host) +} + +func (t *tokenAwareHostPolicy) AddHosts(hosts []*HostInfo) { + t.mu.Lock() + + for _, host := range hosts { + t.hosts.add(host) + } - t.resetTokenRing() + meta := t.getMetadataForUpdate() + meta.resetTokenRing(t.partitioner, t.hosts.get()) + t.updateReplicas(meta, t.getKeyspaceName()) + t.metadata.Store(meta) + + t.mu.Unlock() + + for _, host := range hosts { + t.fallback.AddHost(host) + } } func (t *tokenAwareHostPolicy) RemoveHost(host *HostInfo) { - t.hosts.remove(host.ConnectAddress()) - t.fallback.RemoveHost(host) + t.mu.Lock() + if t.hosts.remove(host.ConnectAddress()) { + meta := t.getMetadataForUpdate() + meta.resetTokenRing(t.partitioner, t.hosts.get()) + t.updateReplicas(meta, t.getKeyspaceName()) + t.metadata.Store(meta) + } + t.mu.Unlock() - t.resetTokenRing() + t.fallback.RemoveHost(host) } func (t *tokenAwareHostPolicy) HostUp(host *HostInfo) { - t.AddHost(host) + t.fallback.HostUp(host) } func (t *tokenAwareHostPolicy) HostDown(host *HostInfo) { - t.RemoveHost(host) + t.fallback.HostDown(host) } -func (t *tokenAwareHostPolicy) resetTokenRing() { - t.mu.Lock() - defer t.mu.Unlock() - - if t.partitioner == "" { +// getMetadataReadOnly returns current cluster metadata. +// Metadata uses copy on write, so the returned value should be only used for reading. +// To obtain a copy that could be updated, use getMetadataForUpdate instead. +func (t *tokenAwareHostPolicy) getMetadataReadOnly() *clusterMeta { + meta, _ := t.metadata.Load().(*clusterMeta) + return meta +} + +// getMetadataForUpdate returns clusterMeta suitable for updating. +// It is a SHALLOW copy of current metadata in case it was already set or new empty clusterMeta otherwise. +// This function should be called with t.mu mutex locked and the mutex should not be released before +// storing the new metadata. +func (t *tokenAwareHostPolicy) getMetadataForUpdate() *clusterMeta { + metaReadOnly := t.getMetadataReadOnly() + meta := new(clusterMeta) + if metaReadOnly != nil { + *meta = *metaReadOnly + } + return meta +} + +// resetTokenRing creates a new tokenRing. +// It must be called with t.mu locked. 
+func (m *clusterMeta) resetTokenRing(partitioner string, hosts []*HostInfo) { + if partitioner == "" { // partitioner not yet set return } // create a new token ring - hosts := t.hosts.get() - tokenRing, err := newTokenRing(t.partitioner, hosts) + tokenRing, err := newTokenRing(partitioner, hosts) if err != nil { Logger.Printf("Unable to update the token ring due to error: %s", err) return } // replace the token ring - t.tokenRing = tokenRing + m.tokenRing = tokenRing } func (t *tokenAwareHostPolicy) Pick(qry ExecutableQuery) NextHost { @@ -356,45 +584,77 @@ routingKey, err := qry.GetRoutingKey() if err != nil { return t.fallback.Pick(qry) + } else if routingKey == nil { + return t.fallback.Pick(qry) } - if routingKey == nil { + + meta := t.getMetadataReadOnly() + if meta == nil || meta.tokenRing == nil { return t.fallback.Pick(qry) } - t.mu.RLock() - // TODO retrieve a list of hosts based on the replication strategy - host := t.tokenRing.GetHostForPartitionKey(routingKey) - t.mu.RUnlock() + token := meta.tokenRing.partitioner.Hash(routingKey) + ht := meta.replicas[qry.Keyspace()].replicasFor(token) - if host == nil { - return t.fallback.Pick(qry) + var replicas []*HostInfo + if ht == nil { + host, _ := meta.tokenRing.GetHostForToken(token) + replicas = []*HostInfo{host} + } else if t.shuffleReplicas { + replicas = shuffleHosts(replicas) + } else { + replicas = ht.hosts } - // scope these variables for the same lifetime as the iterator function var ( - hostReturned bool fallbackIter NextHost + i, j int + remote []*HostInfo ) + used := make(map[*HostInfo]bool, len(replicas)) return func() SelectedHost { - if !hostReturned { - hostReturned = true - return (*selectedHost)(host) + for i < len(replicas) { + h := replicas[i] + i++ + + if !t.fallback.IsLocal(h) { + remote = append(remote, h) + continue + } + + if h.IsUp() { + used[h] = true + return (*selectedHost)(h) + } + } + + if t.nonLocalReplicasFallback { + for j < len(remote) { + h := remote[j] + j++ + + if h.IsUp() { + used[h] = true + return (*selectedHost)(h) + } + } } - // fallback if fallbackIter == nil { + // fallback fallbackIter = t.fallback.Pick(qry) } - fallbackHost := fallbackIter() - // filter the token aware selected hosts from the fallback hosts - if fallbackHost != nil && fallbackHost.Info() == host { - fallbackHost = fallbackIter() + for fallbackHost := fallbackIter(); fallbackHost != nil; fallbackHost = fallbackIter() { + if !used[fallbackHost.Info()] { + used[fallbackHost.Info()] = true + return fallbackHost + } } - return fallbackHost + return nil } } @@ -422,6 +682,11 @@ hostMap map[string]*HostInfo } +func (r *hostPoolHostPolicy) Init(*Session) {} +func (r *hostPoolHostPolicy) KeyspaceChanged(KeyspaceUpdateEvent) {} +func (r *hostPoolHostPolicy) SetPartitioner(string) {} +func (r *hostPoolHostPolicy) IsLocal(*HostInfo) bool { return true } + func (r *hostPoolHostPolicy) SetHosts(hosts []*HostInfo) { peers := make([]string, len(hosts)) hostMap := make(map[string]*HostInfo, len(hosts)) @@ -486,10 +751,6 @@ r.RemoveHost(host) } -func (r *hostPoolHostPolicy) SetPartitioner(partitioner string) { - // noop -} - func (r *hostPoolHostPolicy) Pick(qry ExecutableQuery) NextHost { return func() SelectedHost { r.mu.RLock() @@ -541,8 +802,6 @@ type dcAwareRR struct { local string - pos uint32 - mu sync.RWMutex localHosts cowHostList remoteHosts cowHostList } @@ -551,13 +810,19 @@ // return hosts which are in the local datacentre before returning hosts in all // other datercentres func DCAwareRoundRobinPolicy(localDC string) 
HostSelectionPolicy { - return &dcAwareRR{ - local: localDC, - } + return &dcAwareRR{local: localDC} +} + +func (d *dcAwareRR) Init(*Session) {} +func (d *dcAwareRR) KeyspaceChanged(KeyspaceUpdateEvent) {} +func (d *dcAwareRR) SetPartitioner(p string) {} + +func (d *dcAwareRR) IsLocal(host *HostInfo) bool { + return host.DataCenter() == d.local } func (d *dcAwareRR) AddHost(host *HostInfo) { - if host.DataCenter() == d.local { + if d.IsLocal(host) { d.localHosts.add(host) } else { d.remoteHosts.add(host) @@ -565,46 +830,142 @@ } func (d *dcAwareRR) RemoveHost(host *HostInfo) { - if host.DataCenter() == d.local { + if d.IsLocal(host) { d.localHosts.remove(host.ConnectAddress()) } else { d.remoteHosts.remove(host.ConnectAddress()) } } -func (d *dcAwareRR) HostUp(host *HostInfo) { - d.AddHost(host) -} +func (d *dcAwareRR) HostUp(host *HostInfo) { d.AddHost(host) } +func (d *dcAwareRR) HostDown(host *HostInfo) { d.RemoveHost(host) } + +var randSeed int64 -func (d *dcAwareRR) HostDown(host *HostInfo) { - d.RemoveHost(host) +func init() { + p := make([]byte, 8) + if _, err := crand.Read(p); err != nil { + panic(err) + } + randSeed = int64(binary.BigEndian.Uint64(p)) } -func (d *dcAwareRR) SetPartitioner(p string) {} +func randSource() rand.Source { + return rand.NewSource(atomic.AddInt64(&randSeed, 1)) +} -func (d *dcAwareRR) Pick(q ExecutableQuery) NextHost { +func roundRobbin(hosts []*HostInfo) NextHost { var i int return func() SelectedHost { - var hosts []*HostInfo - localHosts := d.localHosts.get() - remoteHosts := d.remoteHosts.get() - if len(localHosts) != 0 { - hosts = localHosts - } else { - hosts = remoteHosts - } - if len(hosts) == 0 { - return nil - } + for i < len(hosts) { + h := hosts[i] + i++ - // always increment pos to evenly distribute traffic in case of - // failures - pos := atomic.AddUint32(&d.pos, 1) - 1 - if i >= len(localHosts)+len(remoteHosts) { - return nil + if h.IsUp() { + return (*selectedHost)(h) + } } - host := hosts[(pos)%uint32(len(hosts))] - i++ - return (*selectedHost)(host) + + return nil } } + +func (d *dcAwareRR) Pick(q ExecutableQuery) NextHost { + local := d.localHosts.get() + remote := d.remoteHosts.get() + + hosts := make([]*HostInfo, len(local)+len(remote)) + n := copy(hosts, local) + copy(hosts[n:], remote) + + // TODO: use random chose-2 but that will require plumbing information + // about connection/host load to here + r := rand.New(randSource()) + for _, l := range [][]*HostInfo{hosts[:len(local)], hosts[len(local):]} { + r.Shuffle(len(l), func(i, j int) { + l[i], l[j] = l[j], l[i] + }) + } + + return roundRobbin(hosts) +} + +// ConvictionPolicy interface is used by gocql to determine if a host should be +// marked as DOWN based on the error and host info +type ConvictionPolicy interface { + // Implementations should return `true` if the host should be convicted, `false` otherwise. + AddFailure(error error, host *HostInfo) bool + //Implementations should clear out any convictions or state regarding the host. + Reset(host *HostInfo) +} + +// SimpleConvictionPolicy implements a ConvictionPolicy which convicts all hosts +// regardless of error +type SimpleConvictionPolicy struct { +} + +func (e *SimpleConvictionPolicy) AddFailure(error error, host *HostInfo) bool { + return true +} + +func (e *SimpleConvictionPolicy) Reset(host *HostInfo) {} + +// ReconnectionPolicy interface is used by gocql to determine if reconnection +// can be attempted after connection error. 
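
As a usage sketch for the reconnection policies defined just below (address and intervals are placeholders), either implementation can be assigned to the cluster config in the same way the ConstantReconnectionPolicy example in this hunk shows:

    package main

    import (
        "log"
        "time"

        "github.com/gocql/gocql"
    )

    func main() {
        cluster := gocql.NewCluster("127.0.0.1")
        // Allow up to 10 reconnection attempts with an exponentially derived interval.
        cluster.ReconnectionPolicy = &gocql.ExponentialReconnectionPolicy{
            MaxRetries:      10,
            InitialInterval: time.Second,
        }

        session, err := cluster.CreateSession()
        if err != nil {
            log.Fatal(err)
        }
        defer session.Close()
    }
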
The interface allows gocql users +// to implement their own logic to determine how to attempt reconnection. +// +type ReconnectionPolicy interface { + GetInterval(currentRetry int) time.Duration + GetMaxRetries() int +} + +// ConstantReconnectionPolicy has simple logic for returning a fixed reconnection interval. +// +// Examples of usage: +// +// cluster.ReconnectionPolicy = &gocql.ConstantReconnectionPolicy{MaxRetries: 10, Interval: 8 * time.Second} +// +type ConstantReconnectionPolicy struct { + MaxRetries int + Interval time.Duration +} + +func (c *ConstantReconnectionPolicy) GetInterval(currentRetry int) time.Duration { + return c.Interval +} + +func (c *ConstantReconnectionPolicy) GetMaxRetries() int { + return c.MaxRetries +} + +// ExponentialReconnectionPolicy returns a growing reconnection interval. +type ExponentialReconnectionPolicy struct { + MaxRetries int + InitialInterval time.Duration +} + +func (e *ExponentialReconnectionPolicy) GetInterval(currentRetry int) time.Duration { + return getExponentialTime(e.InitialInterval, math.MaxInt16*time.Second, e.GetMaxRetries()) +} + +func (e *ExponentialReconnectionPolicy) GetMaxRetries() int { + return e.MaxRetries +} + +type SpeculativeExecutionPolicy interface { + Attempts() int + Delay() time.Duration +} + +type NonSpeculativeExecution struct{} + +func (sp NonSpeculativeExecution) Attempts() int { return 0 } // No additional attempts +func (sp NonSpeculativeExecution) Delay() time.Duration { return 1 } // The delay. Must be positive to be used in a ticker. + +type SimpleSpeculativeExecution struct { + NumAttempts int + TimeoutDelay time.Duration +} + +func (sp *SimpleSpeculativeExecution) Attempts() int { return sp.NumAttempts } +func (sp *SimpleSpeculativeExecution) Delay() time.Duration { return sp.TimeoutDelay } diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/policies_test.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/policies_test.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/policies_test.go 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/policies_test.go 2019-11-02 13:15:23.000000000 +0000 @@ -5,6 +5,7 @@ package gocql import ( + "errors" "fmt" "net" "testing" @@ -14,7 +15,7 @@ ) // Tests of the round-robin host selection policy implementation -func TestRoundRobinHostPolicy(t *testing.T) { +func TestRoundRobbin(t *testing.T) { policy := RoundRobinHostPolicy() hosts := [...]*HostInfo{ @@ -26,37 +27,33 @@ policy.AddHost(host) } - // interleaved iteration should always increment the host - iterA := policy.Pick(nil) - if actual := iterA(); actual.Info() != hosts[0] { - t.Errorf("Expected hosts[0] but was hosts[%s]", actual.Info().HostID()) - } - iterB := policy.Pick(nil) - if actual := iterB(); actual.Info() != hosts[1] { - t.Errorf("Expected hosts[1] but was hosts[%s]", actual.Info().HostID()) - } - if actual := iterB(); actual.Info() != hosts[0] { - t.Errorf("Expected hosts[0] but was hosts[%s]", actual.Info().HostID()) - } - if actual := iterA(); actual.Info() != hosts[1] { - t.Errorf("Expected hosts[1] but was hosts[%s]", actual.Info().HostID()) - } - - iterC := policy.Pick(nil) - if actual := iterC(); actual.Info() != hosts[0] { - t.Errorf("Expected hosts[0] but was hosts[%s]", actual.Info().HostID()) + got := make(map[string]bool) + it := policy.Pick(nil) + for h := it(); h != nil; h = it() { + id := h.Info().hostId + if got[id] { + t.Fatalf("got duplicate host: %v", id) + } + got[id] = true } - if actual := iterC(); actual.Info() != hosts[1] { 
- t.Errorf("Expected hosts[1] but was hosts[%s]", actual.Info().HostID()) + if len(got) != len(hosts) { + t.Fatalf("expected %d hosts got %d", len(hosts), len(got)) } } // Tests of the token-aware host selection policy implementation with a // round-robin host selection policy fallback. -func TestTokenAwareHostPolicy(t *testing.T) { +func TestHostPolicy_TokenAware_SimpleStrategy(t *testing.T) { + const keyspace = "myKeyspace" policy := TokenAwareHostPolicy(RoundRobinHostPolicy()) + policyInternal := policy.(*tokenAwareHostPolicy) + policyInternal.getKeyspaceName = func() string { return keyspace } + policyInternal.getKeyspaceMetadata = func(ks string) (*KeyspaceMetadata, error) { + return nil, errors.New("not initalized") + } query := &Query{} + query.getKeyspace = func() string { return keyspace } iter := policy.Pick(nil) if iter == nil { @@ -69,48 +66,52 @@ // set the hosts hosts := [...]*HostInfo{ - {connectAddress: net.IPv4(10, 0, 0, 1), tokens: []string{"00"}}, - {connectAddress: net.IPv4(10, 0, 0, 2), tokens: []string{"25"}}, - {connectAddress: net.IPv4(10, 0, 0, 3), tokens: []string{"50"}}, - {connectAddress: net.IPv4(10, 0, 0, 4), tokens: []string{"75"}}, + {hostId: "0", connectAddress: net.IPv4(10, 0, 0, 1), tokens: []string{"00"}}, + {hostId: "1", connectAddress: net.IPv4(10, 0, 0, 2), tokens: []string{"25"}}, + {hostId: "2", connectAddress: net.IPv4(10, 0, 0, 3), tokens: []string{"50"}}, + {hostId: "3", connectAddress: net.IPv4(10, 0, 0, 4), tokens: []string{"75"}}, } - for _, host := range hosts { + for _, host := range &hosts { policy.AddHost(host) } - // the token ring is not setup without the partitioner, but the fallback - // should work - if actual := policy.Pick(nil)(); !actual.Info().ConnectAddress().Equal(hosts[0].ConnectAddress()) { - t.Errorf("Expected peer 0 but was %s", actual.Info().ConnectAddress()) - } - - query.RoutingKey([]byte("30")) - if actual := policy.Pick(query)(); !actual.Info().ConnectAddress().Equal(hosts[1].ConnectAddress()) { - t.Errorf("Expected peer 1 but was %s", actual.Info().ConnectAddress()) - } - policy.SetPartitioner("OrderedPartitioner") + policyInternal.getKeyspaceMetadata = func(keyspaceName string) (*KeyspaceMetadata, error) { + if keyspaceName != keyspace { + return nil, fmt.Errorf("unknown keyspace: %s", keyspaceName) + } + return &KeyspaceMetadata{ + Name: keyspace, + StrategyClass: "SimpleStrategy", + StrategyOptions: map[string]interface{}{ + "class": "SimpleStrategy", + "replication_factor": 2, + }, + }, nil + } + policy.KeyspaceChanged(KeyspaceUpdateEvent{Keyspace: keyspace}) + + // The SimpleStrategy above should generate the following replicas. + // It's handy to have as reference here. 
+ assertDeepEqual(t, "replicas", map[string]tokenRingReplicas{ + "myKeyspace": { + {orderedToken("00"), []*HostInfo{hosts[0], hosts[1]}}, + {orderedToken("25"), []*HostInfo{hosts[1], hosts[2]}}, + {orderedToken("50"), []*HostInfo{hosts[2], hosts[3]}}, + {orderedToken("75"), []*HostInfo{hosts[3], hosts[0]}}, + }, + }, policyInternal.getMetadataReadOnly().replicas) + // now the token ring is configured query.RoutingKey([]byte("20")) iter = policy.Pick(query) - if actual := iter(); !actual.Info().ConnectAddress().Equal(hosts[1].ConnectAddress()) { - t.Errorf("Expected peer 1 but was %s", actual.Info().ConnectAddress()) - } - // rest are round robin - if actual := iter(); !actual.Info().ConnectAddress().Equal(hosts[2].ConnectAddress()) { - t.Errorf("Expected peer 2 but was %s", actual.Info().ConnectAddress()) - } - if actual := iter(); !actual.Info().ConnectAddress().Equal(hosts[3].ConnectAddress()) { - t.Errorf("Expected peer 3 but was %s", actual.Info().ConnectAddress()) - } - if actual := iter(); !actual.Info().ConnectAddress().Equal(hosts[0].ConnectAddress()) { - t.Errorf("Expected peer 0 but was %s", actual.Info().ConnectAddress()) - } + iterCheck(t, iter, "0") + iterCheck(t, iter, "1") } // Tests of the host pool host selection policy implementation -func TestHostPoolHostPolicy(t *testing.T) { +func TestHostPolicy_HostPool(t *testing.T) { policy := HostPoolHostPolicy(hostpool.New(nil)) hosts := []*HostInfo{ @@ -150,7 +151,7 @@ actualD.Mark(nil) } -func TestRoundRobinNilHostInfo(t *testing.T) { +func TestHostPolicy_RoundRobin_NilHostInfo(t *testing.T) { policy := RoundRobinHostPolicy() host := &HostInfo{hostId: "host-1"} @@ -175,8 +176,13 @@ } } -func TestTokenAwareNilHostInfo(t *testing.T) { +func TestHostPolicy_TokenAware_NilHostInfo(t *testing.T) { policy := TokenAwareHostPolicy(RoundRobinHostPolicy()) + policyInternal := policy.(*tokenAwareHostPolicy) + policyInternal.getKeyspaceName = func() string { return "myKeyspace" } + policyInternal.getKeyspaceMetadata = func(ks string) (*KeyspaceMetadata, error) { + return nil, errors.New("not initialized") + } hosts := [...]*HostInfo{ {connectAddress: net.IPv4(10, 0, 0, 0), tokens: []string{"00"}}, @@ -190,6 +196,7 @@ policy.SetPartitioner("OrderedPartitioner") query := &Query{} + query.getKeyspace = func() string { return "myKeyspace" } query.RoutingKey([]byte("20")) iter := policy.Pick(query) @@ -264,7 +271,7 @@ } for _, c := range cases { - q.attempts = c.attempts + q.metrics = preFilledQueryMetrics(map[string]*hostMetrics{"127.0.0.1": {Attempts: c.attempts}}) if c.allow && !rt.Attempt(q) { t.Fatalf("should allow retry after %d attempts", c.attempts) } @@ -302,7 +309,78 @@ } } -func TestDCAwareRR(t *testing.T) { +func TestDowngradingConsistencyRetryPolicy(t *testing.T) { + + q := &Query{cons: LocalQuorum} + + rewt0 := &RequestErrWriteTimeout{ + Received: 0, + WriteType: "SIMPLE", + } + + rewt1 := &RequestErrWriteTimeout{ + Received: 1, + WriteType: "BATCH", + } + + rewt2 := &RequestErrWriteTimeout{ + WriteType: "UNLOGGED_BATCH", + } + + rert := &RequestErrReadTimeout{} + + reu0 := &RequestErrUnavailable{ + Alive: 0, + } + + reu1 := &RequestErrUnavailable{ + Alive: 1, + } + + // this should allow a total of 3 tries. 
+ consistencyLevels := []Consistency{Three, Two, One} + rt := &DowngradingConsistencyRetryPolicy{ConsistencyLevelsToTry: consistencyLevels} + cases := []struct { + attempts int + allow bool + err error + retryType RetryType + }{ + {0, true, rewt0, Rethrow}, + {3, true, rewt1, Ignore}, + {1, true, rewt2, Retry}, + {2, true, rert, Retry}, + {4, false, reu0, Rethrow}, + {16, false, reu1, Retry}, + } + + for _, c := range cases { + q.metrics = preFilledQueryMetrics(map[string]*hostMetrics{"127.0.0.1": {Attempts: c.attempts}}) + if c.retryType != rt.GetRetryType(c.err) { + t.Fatalf("retry type should be %v", c.retryType) + } + if c.allow && !rt.Attempt(q) { + t.Fatalf("should allow retry after %d attempts", c.attempts) + } + if !c.allow && rt.Attempt(q) { + t.Fatalf("should not allow retry after %d attempts", c.attempts) + } + } +} + +func iterCheck(t *testing.T, iter NextHost, hostID string) { + t.Helper() + + host := iter() + if host == nil || host.Info() == nil { + t.Fatalf("expected hostID %s got nil", hostID) + } + if host.Info().HostID() != hostID { + t.Fatalf("Expected peer %s but was %s", hostID, host.Info().HostID()) + } +} + +func TestHostPolicy_DCAwareRR(t *testing.T) { p := DCAwareRoundRobinPolicy("local") hosts := [...]*HostInfo{ @@ -316,36 +394,228 @@ p.AddHost(host) } - // interleaved iteration should always increment the host - iterA := p.Pick(nil) - if actual := iterA(); actual.Info() != hosts[0] { - t.Errorf("Expected hosts[0] but was hosts[%s]", actual.Info().HostID()) - } - iterB := p.Pick(nil) - if actual := iterB(); actual.Info() != hosts[1] { - t.Errorf("Expected hosts[1] but was hosts[%s]", actual.Info().HostID()) - } - if actual := iterB(); actual.Info() != hosts[0] { - t.Errorf("Expected hosts[0] but was hosts[%s]", actual.Info().HostID()) - } - if actual := iterA(); actual.Info() != hosts[1] { - t.Errorf("Expected hosts[1] but was hosts[%s]", actual.Info().HostID()) - } - iterC := p.Pick(nil) - if actual := iterC(); actual.Info() != hosts[0] { - t.Errorf("Expected hosts[0] but was hosts[%s]", actual.Info().HostID()) - } - p.RemoveHost(hosts[0]) - if actual := iterC(); actual.Info() != hosts[1] { - t.Errorf("Expected hosts[1] but was hosts[%s]", actual.Info().HostID()) - } - p.RemoveHost(hosts[1]) - iterD := p.Pick(nil) - if actual := iterD(); actual.Info() != hosts[2] { - t.Errorf("Expected hosts[2] but was hosts[%s]", actual.Info().HostID()) + got := make(map[string]bool, len(hosts)) + var dcs []string + + it := p.Pick(nil) + for h := it(); h != nil; h = it() { + id := h.Info().hostId + dc := h.Info().dataCenter + + if got[id] { + t.Fatalf("got duplicate host %s", id) + } + got[id] = true + dcs = append(dcs, dc) + } + + if len(got) != len(hosts) { + t.Fatalf("expected %d hosts got %d", len(hosts), len(got)) } - if actual := iterD(); actual.Info() != hosts[3] { - t.Errorf("Expected hosts[3] but was hosts[%s]", actual.Info().HostID()) + + var remote bool + for _, dc := range dcs { + if dc == "local" { + if remote { + t.Fatalf("got local dc after remote: %v", dcs) + } + } else { + remote = true + } } } + +// Tests of the token-aware host selection policy implementation with a +// DC aware round-robin host selection policy fallback +// with {"class": "NetworkTopologyStrategy", "a": 1, "b": 1, "c": 1} replication. 
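
A hedged sketch of how DowngradingConsistencyRetryPolicy, exercised by the test above, is meant to be configured; the address is a placeholder and the consistency ladder mirrors the one used in the test:

    package main

    import (
        "log"

        "github.com/gocql/gocql"
    )

    func main() {
        cluster := gocql.NewCluster("127.0.0.1")
        cluster.Consistency = gocql.Quorum
        // On a retriable error, the nth retry runs at the nth level in this list.
        cluster.RetryPolicy = &gocql.DowngradingConsistencyRetryPolicy{
            ConsistencyLevelsToTry: []gocql.Consistency{gocql.Three, gocql.Two, gocql.One},
        }

        session, err := cluster.CreateSession()
        if err != nil {
            log.Fatal(err)
        }
        defer session.Close()
    }
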
+func TestHostPolicy_TokenAware(t *testing.T) { + const keyspace = "myKeyspace" + policy := TokenAwareHostPolicy(DCAwareRoundRobinPolicy("local")) + policyInternal := policy.(*tokenAwareHostPolicy) + policyInternal.getKeyspaceName = func() string { return keyspace } + policyInternal.getKeyspaceMetadata = func(ks string) (*KeyspaceMetadata, error) { + return nil, errors.New("not initialized") + } + + query := &Query{} + query.getKeyspace = func() string { return keyspace } + + iter := policy.Pick(nil) + if iter == nil { + t.Fatal("host iterator was nil") + } + actual := iter() + if actual != nil { + t.Fatalf("expected nil from iterator, but was %v", actual) + } + + // set the hosts + hosts := [...]*HostInfo{ + {hostId: "0", connectAddress: net.IPv4(10, 0, 0, 1), tokens: []string{"05"}, dataCenter: "remote1"}, + {hostId: "1", connectAddress: net.IPv4(10, 0, 0, 2), tokens: []string{"10"}, dataCenter: "local"}, + {hostId: "2", connectAddress: net.IPv4(10, 0, 0, 3), tokens: []string{"15"}, dataCenter: "remote2"}, + {hostId: "3", connectAddress: net.IPv4(10, 0, 0, 4), tokens: []string{"20"}, dataCenter: "remote1"}, + {hostId: "4", connectAddress: net.IPv4(10, 0, 0, 5), tokens: []string{"25"}, dataCenter: "local"}, + {hostId: "5", connectAddress: net.IPv4(10, 0, 0, 6), tokens: []string{"30"}, dataCenter: "remote2"}, + {hostId: "6", connectAddress: net.IPv4(10, 0, 0, 7), tokens: []string{"35"}, dataCenter: "remote1"}, + {hostId: "7", connectAddress: net.IPv4(10, 0, 0, 8), tokens: []string{"40"}, dataCenter: "local"}, + {hostId: "8", connectAddress: net.IPv4(10, 0, 0, 9), tokens: []string{"45"}, dataCenter: "remote2"}, + {hostId: "9", connectAddress: net.IPv4(10, 0, 0, 10), tokens: []string{"50"}, dataCenter: "remote1"}, + {hostId: "10", connectAddress: net.IPv4(10, 0, 0, 11), tokens: []string{"55"}, dataCenter: "local"}, + {hostId: "11", connectAddress: net.IPv4(10, 0, 0, 12), tokens: []string{"60"}, dataCenter: "remote2"}, + } + for _, host := range hosts { + policy.AddHost(host) + } + + // the token ring is not setup without the partitioner, but the fallback + // should work + if actual := policy.Pick(nil)(); actual == nil { + t.Fatal("expected to get host from fallback got nil") + } + + query.RoutingKey([]byte("30")) + if actual := policy.Pick(query)(); actual == nil { + t.Fatal("expected to get host from fallback got nil") + } + + policy.SetPartitioner("OrderedPartitioner") + + policyInternal.getKeyspaceMetadata = func(keyspaceName string) (*KeyspaceMetadata, error) { + if keyspaceName != keyspace { + return nil, fmt.Errorf("unknown keyspace: %s", keyspaceName) + } + return &KeyspaceMetadata{ + Name: keyspace, + StrategyClass: "NetworkTopologyStrategy", + StrategyOptions: map[string]interface{}{ + "class": "NetworkTopologyStrategy", + "local": 1, + "remote1": 1, + "remote2": 1, + }, + }, nil + } + policy.KeyspaceChanged(KeyspaceUpdateEvent{Keyspace: "myKeyspace"}) + + // The NetworkTopologyStrategy above should generate the following replicas. + // It's handy to have as reference here. 
+ assertDeepEqual(t, "replicas", map[string]tokenRingReplicas{ + "myKeyspace": { + {orderedToken("05"), []*HostInfo{hosts[0], hosts[1], hosts[2]}}, + {orderedToken("10"), []*HostInfo{hosts[1], hosts[2], hosts[3]}}, + {orderedToken("15"), []*HostInfo{hosts[2], hosts[3], hosts[4]}}, + {orderedToken("20"), []*HostInfo{hosts[3], hosts[4], hosts[5]}}, + {orderedToken("25"), []*HostInfo{hosts[4], hosts[5], hosts[6]}}, + {orderedToken("30"), []*HostInfo{hosts[5], hosts[6], hosts[7]}}, + {orderedToken("35"), []*HostInfo{hosts[6], hosts[7], hosts[8]}}, + {orderedToken("40"), []*HostInfo{hosts[7], hosts[8], hosts[9]}}, + {orderedToken("45"), []*HostInfo{hosts[8], hosts[9], hosts[10]}}, + {orderedToken("50"), []*HostInfo{hosts[9], hosts[10], hosts[11]}}, + {orderedToken("55"), []*HostInfo{hosts[10], hosts[11], hosts[0]}}, + {orderedToken("60"), []*HostInfo{hosts[11], hosts[0], hosts[1]}}, + }, + }, policyInternal.getMetadataReadOnly().replicas) + + // now the token ring is configured + query.RoutingKey([]byte("23")) + iter = policy.Pick(query) + // first should be host with matching token from the local DC + iterCheck(t, iter, "4") + // next are in non deterministic order +} + +// Tests of the token-aware host selection policy implementation with a +// DC aware round-robin host selection policy fallback +// with {"class": "NetworkTopologyStrategy", "a": 2, "b": 2, "c": 2} replication. +func TestHostPolicy_TokenAware_NetworkStrategy(t *testing.T) { + const keyspace = "myKeyspace" + policy := TokenAwareHostPolicy(DCAwareRoundRobinPolicy("local"), NonLocalReplicasFallback()) + policyInternal := policy.(*tokenAwareHostPolicy) + policyInternal.getKeyspaceName = func() string { return keyspace } + policyInternal.getKeyspaceMetadata = func(ks string) (*KeyspaceMetadata, error) { + return nil, errors.New("not initialized") + } + + query := &Query{} + query.getKeyspace = func() string { return keyspace } + + iter := policy.Pick(nil) + if iter == nil { + t.Fatal("host iterator was nil") + } + actual := iter() + if actual != nil { + t.Fatalf("expected nil from iterator, but was %v", actual) + } + + // set the hosts + hosts := [...]*HostInfo{ + {hostId: "0", connectAddress: net.IPv4(10, 0, 0, 1), tokens: []string{"05"}, dataCenter: "remote1"}, + {hostId: "1", connectAddress: net.IPv4(10, 0, 0, 2), tokens: []string{"10"}, dataCenter: "local"}, + {hostId: "2", connectAddress: net.IPv4(10, 0, 0, 3), tokens: []string{"15"}, dataCenter: "remote2"}, + {hostId: "3", connectAddress: net.IPv4(10, 0, 0, 4), tokens: []string{"20"}, dataCenter: "remote1"}, // 1 + {hostId: "4", connectAddress: net.IPv4(10, 0, 0, 5), tokens: []string{"25"}, dataCenter: "local"}, // 2 + {hostId: "5", connectAddress: net.IPv4(10, 0, 0, 6), tokens: []string{"30"}, dataCenter: "remote2"}, // 3 + {hostId: "6", connectAddress: net.IPv4(10, 0, 0, 7), tokens: []string{"35"}, dataCenter: "remote1"}, // 4 + {hostId: "7", connectAddress: net.IPv4(10, 0, 0, 8), tokens: []string{"40"}, dataCenter: "local"}, // 5 + {hostId: "8", connectAddress: net.IPv4(10, 0, 0, 9), tokens: []string{"45"}, dataCenter: "remote2"}, // 6 + {hostId: "9", connectAddress: net.IPv4(10, 0, 0, 10), tokens: []string{"50"}, dataCenter: "remote1"}, + {hostId: "10", connectAddress: net.IPv4(10, 0, 0, 11), tokens: []string{"55"}, dataCenter: "local"}, + {hostId: "11", connectAddress: net.IPv4(10, 0, 0, 12), tokens: []string{"60"}, dataCenter: "remote2"}, + } + for _, host := range hosts { + policy.AddHost(host) + } + + policy.SetPartitioner("OrderedPartitioner") + + 
policyInternal.getKeyspaceMetadata = func(keyspaceName string) (*KeyspaceMetadata, error) { + if keyspaceName != keyspace { + return nil, fmt.Errorf("unknown keyspace: %s", keyspaceName) + } + return &KeyspaceMetadata{ + Name: keyspace, + StrategyClass: "NetworkTopologyStrategy", + StrategyOptions: map[string]interface{}{ + "class": "NetworkTopologyStrategy", + "local": 2, + "remote1": 2, + "remote2": 2, + }, + }, nil + } + policy.KeyspaceChanged(KeyspaceUpdateEvent{Keyspace: keyspace}) + + // The NetworkTopologyStrategy above should generate the following replicas. + // It's handy to have as reference here. + assertDeepEqual(t, "replicas", map[string]tokenRingReplicas{ + keyspace: { + {orderedToken("05"), []*HostInfo{hosts[0], hosts[1], hosts[2], hosts[3], hosts[4], hosts[5]}}, + {orderedToken("10"), []*HostInfo{hosts[1], hosts[2], hosts[3], hosts[4], hosts[5], hosts[6]}}, + {orderedToken("15"), []*HostInfo{hosts[2], hosts[3], hosts[4], hosts[5], hosts[6], hosts[7]}}, + {orderedToken("20"), []*HostInfo{hosts[3], hosts[4], hosts[5], hosts[6], hosts[7], hosts[8]}}, + {orderedToken("25"), []*HostInfo{hosts[4], hosts[5], hosts[6], hosts[7], hosts[8], hosts[9]}}, + {orderedToken("30"), []*HostInfo{hosts[5], hosts[6], hosts[7], hosts[8], hosts[9], hosts[10]}}, + {orderedToken("35"), []*HostInfo{hosts[6], hosts[7], hosts[8], hosts[9], hosts[10], hosts[11]}}, + {orderedToken("40"), []*HostInfo{hosts[7], hosts[8], hosts[9], hosts[10], hosts[11], hosts[0]}}, + {orderedToken("45"), []*HostInfo{hosts[8], hosts[9], hosts[10], hosts[11], hosts[0], hosts[1]}}, + {orderedToken("50"), []*HostInfo{hosts[9], hosts[10], hosts[11], hosts[0], hosts[1], hosts[2]}}, + {orderedToken("55"), []*HostInfo{hosts[10], hosts[11], hosts[0], hosts[1], hosts[2], hosts[3]}}, + {orderedToken("60"), []*HostInfo{hosts[11], hosts[0], hosts[1], hosts[2], hosts[3], hosts[4]}}, + }, + }, policyInternal.getMetadataReadOnly().replicas) + + // now the token ring is configured + query.RoutingKey([]byte("23")) + iter = policy.Pick(query) + // first should be hosts with matching token from the local DC + iterCheck(t, iter, "4") + iterCheck(t, iter, "7") + // rest should be hosts with matching token from remote DCs + iterCheck(t, iter, "3") + iterCheck(t, iter, "5") + iterCheck(t, iter, "6") + iterCheck(t, iter, "8") +} diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/query_executor.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/query_executor.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/query_executor.go 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/query_executor.go 2019-11-02 13:15:23.000000000 +0000 @@ -1,14 +1,21 @@ package gocql import ( + "context" "time" ) type ExecutableQuery interface { - execute(conn *Conn) *Iter - attempt(time.Duration) + execute(ctx context.Context, conn *Conn) *Iter + attempt(keyspace string, end, start time.Time, iter *Iter, host *HostInfo) retryPolicy() RetryPolicy + speculativeExecutionPolicy() SpeculativeExecutionPolicy GetRoutingKey() ([]byte, error) + Keyspace() string + IsIdempotent() bool + + withContext(context.Context) ExecutableQuery + RetryableQuery } @@ -17,50 +24,138 @@ policy HostSelectionPolicy } +func (q *queryExecutor) attemptQuery(ctx context.Context, qry ExecutableQuery, conn *Conn) *Iter { + start := time.Now() + iter := qry.execute(ctx, conn) + end := time.Now() + + qry.attempt(q.pool.keyspace, end, start, iter, conn.host) + + return iter +} + +func (q *queryExecutor) speculate(ctx context.Context, qry 
ExecutableQuery, sp SpeculativeExecutionPolicy, results chan *Iter) *Iter { + ticker := time.NewTicker(sp.Delay()) + defer ticker.Stop() + + for i := 0; i < sp.Attempts(); i++ { + select { + case <-ticker.C: + go q.run(ctx, qry, results) + case <-ctx.Done(): + return &Iter{err: ctx.Err()} + case iter := <-results: + return iter + } + } + + return nil +} + func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) { - rt := qry.retryPolicy() + // check if the query is not marked as idempotent, if + // it is, we force the policy to NonSpeculative + sp := qry.speculativeExecutionPolicy() + if !qry.IsIdempotent() || sp.Attempts() == 0 { + return q.do(qry.Context(), qry), nil + } + + ctx, cancel := context.WithCancel(qry.Context()) + defer cancel() + + results := make(chan *Iter, 1) + + // Launch the main execution + go q.run(ctx, qry, results) + + // The speculative executions are launched _in addition_ to the main + // execution, on a timer. So Speculation{2} would make 3 executions running + // in total. + if iter := q.speculate(ctx, qry, sp, results); iter != nil { + return iter, nil + } + + select { + case iter := <-results: + return iter, nil + case <-ctx.Done(): + return &Iter{err: ctx.Err()}, nil + } +} + +func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery) *Iter { hostIter := q.policy.Pick(qry) + selectedHost := hostIter() + rt := qry.retryPolicy() + var lastErr error var iter *Iter - for hostResponse := hostIter(); hostResponse != nil; hostResponse = hostIter() { - host := hostResponse.Info() + for selectedHost != nil { + host := selectedHost.Info() if host == nil || !host.IsUp() { + selectedHost = hostIter() continue } pool, ok := q.pool.getPool(host) if !ok { + selectedHost = hostIter() continue } conn := pool.Pick() if conn == nil { + selectedHost = hostIter() continue } - start := time.Now() - iter = qry.execute(conn) - - qry.attempt(time.Since(start)) - + iter = q.attemptQuery(ctx, qry, conn) + iter.host = selectedHost.Info() // Update host - hostResponse.Mark(iter.err) + switch iter.err { + case context.Canceled, context.DeadlineExceeded, ErrNotFound: + // those errors represents logical errors, they should not count + // toward removing a node from the pool + selectedHost.Mark(nil) + return iter + default: + selectedHost.Mark(iter.err) + } - // Exit for loop if the query was successful - if iter.err == nil { - iter.host = host - return iter, nil + // Exit if the query was successful + // or no retry policy defined or retry attempts were reached + if iter.err == nil || rt == nil || !rt.Attempt(qry) { + return iter } + lastErr = iter.err - if rt == nil || !rt.Attempt(qry) { - // What do here? Should we just return an error here? - break + // If query is unsuccessful, check the error with RetryPolicy to retry + switch rt.GetRetryType(iter.err) { + case Retry: + // retry on the same host + continue + case Rethrow, Ignore: + return iter + case RetryNextHost: + // retry on the next host + selectedHost = hostIter() + continue + default: + // Undefined? 
Return nil and error, this will panic in the requester + return &Iter{err: ErrUnknownRetryType} } } - if iter == nil { - return nil, ErrNoConnections + if lastErr != nil { + return &Iter{err: lastErr} } - return iter, nil + return &Iter{err: ErrNoConnections} +} + +func (q *queryExecutor) run(ctx context.Context, qry ExecutableQuery, results chan<- *Iter) { + select { + case results <- q.do(ctx, qry): + case <-ctx.Done(): + } } diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/README.md golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/README.md --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/README.md 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/README.md 2019-11-02 13:15:23.000000000 +0000 @@ -17,10 +17,10 @@ The following matrix shows the versions of Go and Cassandra that are tested with the integration test suite as part of the CI build: -Go/Cassandra | 2.1.x | 2.2.x | 3.0.x +Go/Cassandra | 2.1.x | 2.2.x | 3.x.x -------------| -------| ------| --------- -1.8 | yes | yes | yes -1.9 | yes | yes | yes +1.12 | yes | yes | yes +1.13 | yes | yes | yes Gocql has been tested in production against many different versions of Cassandra. Due to limits in our CI setup we only test against the latest 3 major releases, which coincide with the official support from the Apache project. @@ -195,6 +195,7 @@ * [gocqltable](https://github.com/kristoiv/gocqltable) is a wrapper around gocql that aims to simplify common operations. * [gockle](https://github.com/willfaught/gockle) provides simple, mockable interfaces that wrap gocql types * [scylladb](https://github.com/scylladb/scylla) is a fast Apache Cassandra-compatible NoSQL database +* [go-cql-driver](https://github.com/MichaelS11/go-cql-driver) is an CQL driver conforming to the built-in database/sql interface. It is good for simple use cases where the database/sql interface is wanted. The CQL driver is a wrapper around this project. 
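
Stepping back to the speculative execution support added to query_executor.go above: extra executions are only launched for queries explicitly marked idempotent. A hedged usage sketch, with placeholder address, keyspace, table and timings:

    package main

    import (
        "log"
        "time"

        "github.com/gocql/gocql"
    )

    func main() {
        cluster := gocql.NewCluster("127.0.0.1")
        session, err := cluster.CreateSession()
        if err != nil {
            log.Fatal(err)
        }
        defer session.Close()

        // With NumAttempts: 2, up to three executions of the same query can be
        // in flight; the first result wins. Only safe for idempotent statements.
        var name string
        err = session.Query(`SELECT name FROM example_ks.users WHERE id = ?`, 42).
            Idempotent(true).
            SetSpeculativeExecutionPolicy(&gocql.SimpleSpeculativeExecution{
                NumAttempts:  2,
                TimeoutDelay: 50 * time.Millisecond,
            }).
            Scan(&name)
        if err != nil {
            log.Fatal(err)
        }
        log.Println(name)
    }
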
Other Projects -------------- diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/ring.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/ring.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/ring.go 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/ring.go 2019-11-02 13:15:23.000000000 +0000 @@ -64,6 +64,8 @@ } func (r *ring) addHost(host *HostInfo) bool { + // TODO(zariel): key all host info by HostID instead of + // ip addresses if host.invalidConnectAddr() { panic(fmt.Sprintf("invalid host: %v", host)) } @@ -140,8 +142,8 @@ } func (c *clusterMetadata) setPartitioner(partitioner string) { - c.mu.RLock() - defer c.mu.RUnlock() + c.mu.Lock() + defer c.mu.Unlock() if c.partitioner != partitioner { // TODO: update other things now diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/session_connect_test.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/session_connect_test.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/session_connect_test.go 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/session_connect_test.go 2019-11-02 13:15:23.000000000 +0000 @@ -91,41 +91,3 @@ } } } - -func TestSession_connect_WithNoTranslator(t *testing.T) { - srvr, err := NewOneConnTestServer() - assertNil(t, "error when creating tcp server", err) - defer srvr.Close() - - session := createTestSession() - defer session.Close() - - go srvr.Serve() - - Connect(&HostInfo{ - connectAddress: srvr.Addr, - port: srvr.Port, - }, session.connCfg, testConnErrorHandler(t), session) - - assertConnectionEventually(t, 500*time.Millisecond, srvr) -} - -func TestSession_connect_WithTranslator(t *testing.T) { - srvr, err := NewOneConnTestServer() - assertNil(t, "error when creating tcp server", err) - defer srvr.Close() - - session := createTestSession() - defer session.Close() - session.cfg.AddressTranslator = staticAddressTranslator(srvr.Addr, srvr.Port) - - go srvr.Serve() - - // the provided address will be translated - Connect(&HostInfo{ - connectAddress: net.ParseIP("10.10.10.10"), - port: 5432, - }, session.connCfg, testConnErrorHandler(t), session) - - assertConnectionEventually(t, 500*time.Millisecond, srvr) -} diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/session.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/session.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/session.go 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/session.go 2019-11-02 13:15:23.000000000 +0000 @@ -37,6 +37,10 @@ routingKeyInfoCache routingKeyInfoLRU schemaDescriber *schemaDescriber trace Tracer + queryObserver QueryObserver + batchObserver BatchObserver + connectObserver ConnectObserver + frameObserver FrameHeaderObserver hostSource *ringDescriber stmtsLRU *preparedLRU @@ -58,8 +62,9 @@ schemaEvents *eventDebouncer // ring metadata - hosts []HostInfo - useSystemSchema bool + hosts []HostInfo + useSystemSchema bool + hasAggregatesAndFunctions bool cfg ClusterConfig @@ -78,7 +83,7 @@ func addrsToHosts(addrs []string, defaultPort int) ([]*HostInfo, error) { var hosts []*HostInfo for _, hostport := range addrs { - host, err := hostInfo(hostport, defaultPort) + resolvedHosts, err := hostInfo(hostport, defaultPort) if err != nil { // Try other hosts if unable to resolve DNS name if _, ok := err.(*net.DNSError); ok { @@ -88,7 +93,7 @@ return nil, err } - hosts = append(hosts, host) + hosts = append(hosts, resolvedHosts...) 
} if len(hosts) == 0 { return nil, errors.New("failed to resolve any of the provided hostnames") @@ -103,23 +108,29 @@ return nil, ErrNoHosts } + // Check that either Authenticator is set or AuthProvider, not both + if cfg.Authenticator != nil && cfg.AuthProvider != nil { + return nil, errors.New("Can't use both Authenticator and AuthProvider in cluster config.") + } + s := &Session{ - cons: cfg.Consistency, - prefetch: 0.25, - cfg: cfg, - pageSize: cfg.PageSize, - stmtsLRU: &preparedLRU{lru: lru.New(cfg.MaxPreparedStmts)}, - quit: make(chan struct{}), + cons: cfg.Consistency, + prefetch: 0.25, + cfg: cfg, + pageSize: cfg.PageSize, + stmtsLRU: &preparedLRU{lru: lru.New(cfg.MaxPreparedStmts)}, + quit: make(chan struct{}), + connectObserver: cfg.ConnectObserver, } + s.schemaDescriber = newSchemaDescriber(s) + s.nodeEvents = newEventDebouncer("NodeEvents", s.handleNodeEvent) s.schemaEvents = newEventDebouncer("SchemaEvents", s.handleSchemaEvent) s.routingKeyInfoCache.lru = lru.New(cfg.MaxRoutingKeyInfo) - s.hostSource = &ringDescriber{ - session: s, - } + s.hostSource = &ringDescriber{session: s} if cfg.PoolConfig.HostSelectionPolicy == nil { cfg.PoolConfig.HostSelectionPolicy = RoundRobinHostPolicy() @@ -127,11 +138,18 @@ s.pool = cfg.PoolConfig.buildPool(s) s.policy = cfg.PoolConfig.HostSelectionPolicy + s.policy.Init(s) + s.executor = &queryExecutor{ pool: s.pool, policy: cfg.PoolConfig.HostSelectionPolicy, } + s.queryObserver = cfg.QueryObserver + s.batchObserver = cfg.BatchObserver + s.connectObserver = cfg.ConnectObserver + s.frameObserver = cfg.FrameHeaderObserver + //Check the TLS Config before trying to connect to anything external connCfg, err := connConfig(&s.cfg) if err != nil { @@ -160,6 +178,7 @@ if err != nil { return err } + s.ring.endpoints = hosts if !s.cfg.disableControlConn { s.control = createControlConn(s) @@ -182,17 +201,48 @@ if !s.cfg.DisableInitialHostLookup { var partitioner string - hosts, partitioner, err = s.hostSource.GetHosts() + newHosts, partitioner, err := s.hostSource.GetHosts() if err != nil { return err } s.policy.SetPartitioner(partitioner) + filteredHosts := make([]*HostInfo, 0, len(newHosts)) + for _, host := range newHosts { + if !s.cfg.filterHost(host) { + filteredHosts = append(filteredHosts, host) + } + } + hosts = append(hosts, filteredHosts...) } } + hostMap := make(map[string]*HostInfo, len(hosts)) for _, host := range hosts { + hostMap[host.ConnectAddress().String()] = host + } + + hosts = hosts[:0] + for _, host := range hostMap { host = s.ring.addOrUpdate(host) - s.handleNodeUp(host.ConnectAddress(), host.Port(), false) + if s.cfg.filterHost(host) { + continue + } + + host.setState(NodeUp) + s.pool.addHost(host) + + hosts = append(hosts, host) + } + + type bulkAddHosts interface { + AddHosts([]*HostInfo) + } + if v, ok := s.policy.(bulkAddHosts); ok { + v.AddHosts(hosts) + } else { + for _, host := range hosts { + s.policy.AddHost(host) + } } // TODO(zariel): we probably dont need this any more as we verify that we @@ -210,13 +260,21 @@ newer, _ := checkSystemSchema(s.control) s.useSystemSchema = newer } else { - s.useSystemSchema = hosts[0].Version().Major >= 3 + version := s.ring.rrHost().Version() + s.useSystemSchema = version.AtLeast(3, 0, 0) + s.hasAggregatesAndFunctions = version.AtLeast(2, 2, 0) } if s.pool.Size() == 0 { return ErrNoConnectionsStarted } + // Invoke KeyspaceChanged to let the policy cache the session keyspace + // parameters. This is used by tokenAwareHostPolicy to discover replicas. 
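
Related to the new validation in NewSession above (Authenticator and AuthProvider are now mutually exclusive), a short sketch of the common password-auth setup; the address and credentials are placeholders:

    package main

    import (
        "log"

        "github.com/gocql/gocql"
    )

    func main() {
        cluster := gocql.NewCluster("127.0.0.1")
        // Set exactly one of Authenticator or AuthProvider; configuring both
        // now makes NewSession return an error.
        cluster.Authenticator = gocql.PasswordAuthenticator{
            Username: "cassandra",
            Password: "cassandra",
        }

        session, err := cluster.CreateSession()
        if err != nil {
            log.Fatal(err)
        }
        defer session.Close()
    }
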
+ if !s.cfg.disableControlConn && s.cfg.Keyspace != "" { + s.policy.KeyspaceChanged(KeyspaceUpdateEvent{Keyspace: s.cfg.Keyspace}) + } + return nil } @@ -290,19 +348,11 @@ // value before the query is executed. Query is automatically prepared // if it has not previously been executed. func (s *Session) Query(stmt string, values ...interface{}) *Query { - s.mu.RLock() qry := queryPool.Get().(*Query) + qry.session = s qry.stmt = stmt qry.values = values - qry.cons = s.cons - qry.session = s - qry.pageSize = s.pageSize - qry.trace = s.trace - qry.prefetch = s.prefetch - qry.rt = s.cfg.RetryPolicy - qry.serialCons = s.cfg.SerialConsistency - qry.defaultTimestamp = s.cfg.DefaultTimestamp - s.mu.RUnlock() + qry.defaultsFromSession() return qry } @@ -320,11 +370,11 @@ // During execution, the meta data of the prepared query will be routed to the // binding callback, which is responsible for producing the query argument values. func (s *Session) Bind(stmt string, b func(q *QueryInfo) ([]interface{}, error)) *Query { - s.mu.RLock() - qry := &Query{stmt: stmt, binding: b, cons: s.cons, - session: s, pageSize: s.pageSize, trace: s.trace, - prefetch: s.prefetch, rt: s.cfg.RetryPolicy} - s.mu.RUnlock() + qry := queryPool.Get().(*Query) + qry.session = s + qry.stmt = stmt + qry.binding = b + qry.defaultsFromSession() return qry } @@ -367,7 +417,7 @@ return closed } -func (s *Session) executeQuery(qry *Query) *Iter { +func (s *Session) executeQuery(qry *Query) (it *Iter) { // fail fast if s.Closed() { return &Iter{err: ErrSessionClosed} @@ -395,25 +445,15 @@ // fail fast if s.Closed() { return nil, ErrSessionClosed - } - - if keyspace == "" { + } else if keyspace == "" { return nil, ErrNoKeyspace } - s.mu.Lock() - // lazy-init schemaDescriber - if s.schemaDescriber == nil { - s.schemaDescriber = newSchemaDescriber(s) - } - s.mu.Unlock() - return s.schemaDescriber.getSchema(keyspace) } func (s *Session) getConn() *Conn { hosts := s.ring.allHosts() - var conn *Conn for _, host := range hosts { if !host.IsUp() { continue @@ -422,10 +462,7 @@ pool, ok := s.pool.getPool(host) if !ok { continue - } - - conn = pool.Pick() - if conn != nil { + } else if conn := pool.Pick(); conn != nil { return conn } } @@ -565,8 +602,8 @@ return routingKeyInfo, nil } -func (b *Batch) execute(conn *Conn) *Iter { - return conn.executeBatch(b) +func (b *Batch) execute(ctx context.Context, conn *Conn) *Iter { + return conn.executeBatch(ctx, b) } func (s *Session) executeBatch(batch *Batch) *Iter { @@ -638,8 +675,90 @@ return applied, iter, iter.err } -func (s *Session) connect(host *HostInfo, errorHandler ConnErrorHandler) (*Conn, error) { - return Connect(host, s.connCfg, errorHandler, s) +type hostMetrics struct { + Attempts int + TotalLatency int64 +} + +type queryMetrics struct { + l sync.RWMutex + m map[string]*hostMetrics + // totalAttempts is total number of attempts. + // Equal to sum of all hostMetrics' Attempts. + totalAttempts int +} + +// preFilledQueryMetrics initializes new queryMetrics based on per-host supplied data. +func preFilledQueryMetrics(m map[string]*hostMetrics) *queryMetrics { + qm := &queryMetrics{m: m} + for _, hm := range qm.m { + qm.totalAttempts += hm.Attempts + } + return qm +} + +// hostMetricsLocked gets or creates host metrics for given host. +func (qm *queryMetrics) hostMetrics(host *HostInfo) *hostMetrics { + qm.l.Lock() + metrics := qm.hostMetricsLocked(host) + qm.l.Unlock() + return metrics +} + +// hostMetricsLocked gets or creates host metrics for given host. 
+// It must be called only while holding qm.l lock. +func (qm *queryMetrics) hostMetricsLocked(host *HostInfo) *hostMetrics { + metrics, exists := qm.m[host.ConnectAddress().String()] + if !exists { + // if the host is not in the map, it means it's been accessed for the first time + metrics = &hostMetrics{} + qm.m[host.ConnectAddress().String()] = metrics + } + + return metrics +} + +// attempts returns the number of times the query was executed. +func (qm *queryMetrics) attempts() int { + qm.l.Lock() + attempts := qm.totalAttempts + qm.l.Unlock() + return attempts +} + +// addAttempts adds given number of attempts and returns previous total attempts. +func (qm *queryMetrics) addAttempts(i int, host *HostInfo) int { + qm.l.Lock() + hostMetric := qm.hostMetricsLocked(host) + hostMetric.Attempts += i + attempts := qm.totalAttempts + qm.totalAttempts += i + qm.l.Unlock() + return attempts +} + +func (qm *queryMetrics) latency() int64 { + qm.l.Lock() + var ( + attempts int + latency int64 + ) + for _, metric := range qm.m { + attempts += metric.Attempts + latency += metric.TotalLatency + } + qm.l.Unlock() + if attempts > 0 { + return latency / int64(attempts) + } + return 0 +} + +func (qm *queryMetrics) addLatency(l int64, host *HostInfo) { + qm.l.Lock() + hostMetric := qm.hostMetricsLocked(host) + hostMetric.TotalLatency += l + qm.l.Unlock() } // Query represents a CQL statement that can be executed. @@ -649,22 +768,51 @@ cons Consistency pageSize int routingKey []byte - routingKeyBuffer []byte pageState []byte prefetch float64 trace Tracer + observer QueryObserver session *Session rt RetryPolicy + spec SpeculativeExecutionPolicy binding func(q *QueryInfo) ([]interface{}, error) - attempts int - totalLatency int64 serialCons SerialConsistency defaultTimestamp bool defaultTimestampValue int64 disableSkipMetadata bool context context.Context + idempotent bool + customPayload map[string][]byte + metrics *queryMetrics disableAutoPage bool + + // getKeyspace is field so that it can be overriden in tests + getKeyspace func() string +} + +func (q *Query) defaultsFromSession() { + s := q.session + + s.mu.RLock() + q.cons = s.cons + q.pageSize = s.pageSize + q.trace = s.trace + q.observer = s.queryObserver + q.prefetch = s.prefetch + q.rt = s.cfg.RetryPolicy + q.serialCons = s.cfg.SerialConsistency + q.defaultTimestamp = s.cfg.DefaultTimestamp + q.idempotent = s.cfg.DefaultIdempotence + q.metrics = &queryMetrics{m: make(map[string]*hostMetrics)} + + q.spec = &NonSpeculativeExecution{} + s.mu.RUnlock() +} + +// Statement returns the statement that was used to generate this query. +func (q Query) Statement() string { + return q.stmt } // String implements the stringer interface. @@ -674,15 +822,20 @@ //Attempts returns the number of times the query was executed. func (q *Query) Attempts() int { - return q.attempts + return q.metrics.attempts() +} + +func (q *Query) AddAttempts(i int, host *HostInfo) { + q.metrics.addAttempts(i, host) } //Latency returns the average amount of nanoseconds per attempt of the query. func (q *Query) Latency() int64 { - if q.attempts > 0 { - return q.totalLatency / int64(q.attempts) - } - return 0 + return q.metrics.latency() +} + +func (q *Query) AddLatency(l int64, host *HostInfo) { + q.metrics.addLatency(l, host) } // Consistency sets the consistency level for this query. 
If no consistency @@ -699,6 +852,24 @@ return q.cons } +// Same as Consistency but without a return value +func (q *Query) SetConsistency(c Consistency) { + q.cons = c +} + +// CustomPayload sets the custom payload level for this query. +func (q *Query) CustomPayload(customPayload map[string][]byte) *Query { + q.customPayload = customPayload + return q +} + +func (q *Query) Context() context.Context { + if q.context == nil { + return context.Background() + } + return q.context +} + // Trace enables tracing of this query. Look at the documentation of the // Tracer interface to learn more about tracing. func (q *Query) Trace(trace Tracer) *Query { @@ -706,6 +877,13 @@ return q } +// Observer enables query-level observer on this query. +// The provided observer will be called every time this query is executed. +func (q *Query) Observer(observer QueryObserver) *Query { + q.observer = observer + return q +} + // PageSize will tell the iterator to fetch the result in pages of size n. // This is useful for iterating over large result sets, but setting the // page size too low might decrease the performance. This feature is only @@ -745,27 +923,68 @@ return q } -// WithContext will set the context to use during a query, it will be used to -// timeout when waiting for responses from Cassandra. -func (q *Query) WithContext(ctx context.Context) *Query { - q.context = ctx - return q -} - -func (q *Query) execute(conn *Conn) *Iter { - return conn.executeQuery(q) +func (q *Query) withContext(ctx context.Context) ExecutableQuery { + // I really wish go had covariant types + return q.WithContext(ctx) } -func (q *Query) attempt(d time.Duration) { - q.attempts++ - q.totalLatency += d.Nanoseconds() - // TODO: track latencies per host and things as well instead of just total +// WithContext returns a shallow copy of q with its context +// set to ctx. +// +// The provided context controls the entire lifetime of executing a +// query, queries will be canceled and return once the context is +// canceled. +func (q *Query) WithContext(ctx context.Context) *Query { + q2 := *q + q2.context = ctx + return &q2 +} + +// Deprecate: does nothing, cancel the context passed to WithContext +func (q *Query) Cancel() { + // TODO: delete +} + +func (q *Query) execute(ctx context.Context, conn *Conn) *Iter { + return conn.executeQuery(ctx, q) +} + +func (q *Query) attempt(keyspace string, end, start time.Time, iter *Iter, host *HostInfo) { + attempt := q.metrics.addAttempts(1, host) + q.AddLatency(end.Sub(start).Nanoseconds(), host) + + if q.observer != nil { + q.observer.ObserveQuery(q.Context(), ObservedQuery{ + Keyspace: keyspace, + Statement: q.stmt, + Start: start, + End: end, + Rows: iter.numRows, + Host: host, + Metrics: q.metrics.hostMetrics(host), + Err: iter.err, + Attempt: attempt, + }) + } } func (q *Query) retryPolicy() RetryPolicy { return q.rt } +// Keyspace returns the keyspace the query will be executed against. +func (q *Query) Keyspace() string { + if q.getKeyspace != nil { + return q.getKeyspace() + } + if q.session == nil { + return "" + } + // TODO(chbannis): this should be parsed from the query or we should let + // this be set by users. + return q.session.cfg.Keyspace +} + // GetRoutingKey gets the routing key to use for routing this query. 
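
The per-host metrics and the ObserveQuery call in Query.attempt above surface in application code through a QueryObserver set on the cluster or per query. A minimal, hypothetical observer that logs failed or slow statements (the type name and threshold are illustrative, not part of this patch):

    package main

    import (
        "context"
        "log"
        "time"

        "github.com/gocql/gocql"
    )

    // loggingObserver is a hypothetical QueryObserver implementation.
    type loggingObserver struct {
        slow time.Duration
    }

    func (o loggingObserver) ObserveQuery(ctx context.Context, q gocql.ObservedQuery) {
        latency := q.End.Sub(q.Start)
        if q.Err != nil || latency > o.slow {
            log.Printf("attempt=%d host=%s latency=%s err=%v stmt=%q",
                q.Attempt, q.Host.ConnectAddress(), latency, q.Err, q.Statement)
        }
    }

    func main() {
        cluster := gocql.NewCluster("127.0.0.1")
        cluster.QueryObserver = loggingObserver{slow: 100 * time.Millisecond}

        session, err := cluster.CreateSession()
        if err != nil {
            log.Fatal(err)
        }
        defer session.Close()
    }
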
If // a routing key has not been explicitly set, then the routing key will // be constructed if possible using the keyspace's schema and the query @@ -783,51 +1002,12 @@ } // try to determine the routing key - routingKeyInfo, err := q.session.routingKeyInfo(q.context, q.stmt) + routingKeyInfo, err := q.session.routingKeyInfo(q.Context(), q.stmt) if err != nil { return nil, err } - if routingKeyInfo == nil { - return nil, nil - } - - if len(routingKeyInfo.indexes) == 1 { - // single column routing key - routingKey, err := Marshal( - routingKeyInfo.types[0], - q.values[routingKeyInfo.indexes[0]], - ) - if err != nil { - return nil, err - } - return routingKey, nil - } - - // We allocate that buffer only once, so that further re-bind/exec of the - // same query don't allocate more memory. - if q.routingKeyBuffer == nil { - q.routingKeyBuffer = make([]byte, 0, 256) - } - - // composite routing key - buf := bytes.NewBuffer(q.routingKeyBuffer) - for i := range routingKeyInfo.indexes { - encoded, err := Marshal( - routingKeyInfo.types[i], - q.values[routingKeyInfo.indexes[i]], - ) - if err != nil { - return nil, err - } - lenBuf := []byte{0x00, 0x00} - binary.BigEndian.PutUint16(lenBuf, uint16(len(encoded))) - buf.Write(lenBuf) - buf.Write(encoded) - buf.WriteByte(0x00) - } - routingKey := buf.Bytes() - return routingKey, nil + return createRoutingKey(routingKeyInfo, q.values) } func (q *Query) shouldPrepare() bool { @@ -866,10 +1046,33 @@ return q } +// SetSpeculativeExecutionPolicy sets the execution policy +func (q *Query) SetSpeculativeExecutionPolicy(sp SpeculativeExecutionPolicy) *Query { + q.spec = sp + return q +} + +// speculativeExecutionPolicy fetches the policy +func (q *Query) speculativeExecutionPolicy() SpeculativeExecutionPolicy { + return q.spec +} + +func (q *Query) IsIdempotent() bool { + return q.idempotent +} + +// Idempotent marks the query as being idempotent or not depending on +// the value. +func (q *Query) Idempotent(value bool) *Query { + q.idempotent = value + return q +} + // Bind sets query arguments of query. This can also be used to rebind new query arguments // to an existing query instance. func (q *Query) Bind(v ...interface{}) *Query { q.values = v + q.pageState = nil return q } @@ -915,7 +1118,7 @@ return false } - return strings.ToLower(stmt[0:3]) == "use" + return strings.EqualFold(stmt[0:3], "use") } // Iter executes the query and returns an iterator capable of iterating @@ -1005,25 +1208,7 @@ // reset zeroes out all fields of a query so that it can be safely pooled. func (q *Query) reset() { - q.stmt = "" - q.values = nil - q.cons = 0 - q.pageSize = 0 - q.routingKey = nil - q.routingKeyBuffer = nil - q.pageState = nil - q.prefetch = 0 - q.trace = nil - q.session = nil - q.rt = nil - q.binding = nil - q.attempts = 0 - q.totalLatency = 0 - q.serialCons = 0 - q.defaultTimestamp = false - q.disableSkipMetadata = false - q.disableAutoPage = false - q.context = nil + *q = Query{} } // Iter represents an iterator that can be used to iterate over all rows that @@ -1057,22 +1242,23 @@ // scanned into with Scan. // Next must be called before every call to Scan. Next() bool - + // Scan copies the current row's columns into dest. If the length of dest does not equal // the number of columns returned in the row an error is returned. If an error is encountered // when unmarshalling a column into the value in dest an error is returned and the row is invalidated // until the next call to Next. // Next must be called before calling Scan, if it is not an error is returned. 
Scan(...interface{}) error - + // Err returns the if there was one during iteration that resulted in iteration being unable to complete. // Err will also release resources held by the iterator, the Scanner should not used after being called. Err() error } type iterScanner struct { - iter *Iter - cols [][]byte + iter *Iter + cols [][]byte + valid bool } func (is *iterScanner) Next() bool { @@ -1089,17 +1275,16 @@ return false } - cols := make([][]byte, len(iter.meta.columns)) - for i := 0; i < len(cols); i++ { + for i := 0; i < len(is.cols); i++ { col, err := iter.readColumn() if err != nil { iter.err = err return false } - cols[i] = col + is.cols[i] = col } - is.cols = cols iter.pos++ + is.valid = true return true } @@ -1129,7 +1314,7 @@ } func (is *iterScanner) Scan(dest ...interface{}) error { - if is.cols == nil { + if !is.valid { return errors.New("gocql: Scan called without calling Next") } @@ -1153,8 +1338,7 @@ i += n } - is.cols = nil - + is.valid = false return err } @@ -1162,6 +1346,7 @@ iter := is.iter is.iter = nil is.cols = nil + is.valid = false return iter.Close() } @@ -1172,7 +1357,7 @@ return nil } - return &iterScanner{iter: iter} + return &iterScanner{iter: iter, cols: make([][]byte, len(iter.meta.columns))} } func (iter *Iter) readColumn() ([]byte, error) { @@ -1200,7 +1385,7 @@ return false } - if iter.next != nil && iter.pos == iter.next.pos { + if iter.next != nil && iter.pos >= iter.next.pos { go iter.next.fetch() } @@ -1241,7 +1426,7 @@ // custom QueryHandlers running in your C* cluster. // See https://datastax.github.io/java-driver/manual/custom_payloads/ func (iter *Iter) GetCustomPayload() map[string][]byte { - return iter.framer.header.customPayload + return iter.framer.customPayload } // Warnings returns any warnings generated if given in the response from Cassandra. @@ -1259,7 +1444,6 @@ func (iter *Iter) Close() error { if atomic.CompareAndSwapInt32(&iter.closed, 0, 1) { if iter.framer != nil { - framerPool.Put(iter.framer) iter.framer = nil } } @@ -1284,7 +1468,7 @@ } // PageState return the current paging state for a query which can be used for -// subsequent quries to resume paging this point. +// subsequent queries to resume paging this point. 
func (iter *Iter) PageState() []byte { return iter.meta.pagingState } @@ -1297,7 +1481,7 @@ } type nextIter struct { - qry Query + qry *Query pos int once sync.Once next *Iter @@ -1305,7 +1489,7 @@ func (n *nextIter) fetch() *Iter { n.once.Do(func() { - n.next = n.qry.session.executeQuery(&n.qry) + n.next = n.qry.session.executeQuery(n.qry) }) return n.next } @@ -1314,40 +1498,80 @@ Type BatchType Entries []BatchEntry Cons Consistency + routingKey []byte + routingKeyBuffer []byte + CustomPayload map[string][]byte rt RetryPolicy - attempts int - totalLatency int64 + spec SpeculativeExecutionPolicy + observer BatchObserver + session *Session serialCons SerialConsistency defaultTimestamp bool defaultTimestampValue int64 context context.Context + cancelBatch func() + keyspace string + metrics *queryMetrics } // NewBatch creates a new batch operation without defaults from the cluster +// +// Deprecated: use session.NewBatch instead func NewBatch(typ BatchType) *Batch { - return &Batch{Type: typ} + return &Batch{ + Type: typ, + metrics: &queryMetrics{m: make(map[string]*hostMetrics)}, + spec: &NonSpeculativeExecution{}, + } } // NewBatch creates a new batch operation using defaults defined in the cluster func (s *Session) NewBatch(typ BatchType) *Batch { s.mu.RLock() - batch := &Batch{Type: typ, rt: s.cfg.RetryPolicy, serialCons: s.cfg.SerialConsistency, - Cons: s.cons, defaultTimestamp: s.cfg.DefaultTimestamp} + batch := &Batch{ + Type: typ, + rt: s.cfg.RetryPolicy, + serialCons: s.cfg.SerialConsistency, + observer: s.batchObserver, + session: s, + Cons: s.cons, + defaultTimestamp: s.cfg.DefaultTimestamp, + keyspace: s.cfg.Keyspace, + metrics: &queryMetrics{m: make(map[string]*hostMetrics)}, + spec: &NonSpeculativeExecution{}, + } + s.mu.RUnlock() return batch } +// Observer enables batch-level observer on this batch. +// The provided observer will be called every time this batched query is executed. +func (b *Batch) Observer(observer BatchObserver) *Batch { + b.observer = observer + return b +} + +func (b *Batch) Keyspace() string { + return b.keyspace +} + // Attempts returns the number of attempts made to execute the batch. func (b *Batch) Attempts() int { - return b.attempts + return b.metrics.attempts() +} + +func (b *Batch) AddAttempts(i int, host *HostInfo) { + b.metrics.addAttempts(i, host) } //Latency returns the average number of nanoseconds to execute a single attempt of the batch. func (b *Batch) Latency() int64 { - if b.attempts > 0 { - return b.totalLatency / int64(b.attempts) - } - return 0 + return b.metrics.latency() +} + +func (b *Batch) AddLatency(l int64, host *HostInfo) { + b.metrics.addLatency(l, host) } // GetConsistency returns the currently configured consistency level for the batch @@ -1356,6 +1580,37 @@ return b.Cons } +// SetConsistency sets the currently configured consistency level for the batch +// operation. 
+func (b *Batch) SetConsistency(c Consistency) { + b.Cons = c +} + +func (b *Batch) Context() context.Context { + if b.context == nil { + return context.Background() + } + return b.context +} + +func (b *Batch) IsIdempotent() bool { + for _, entry := range b.Entries { + if !entry.Idempotent { + return false + } + } + return true +} + +func (b *Batch) speculativeExecutionPolicy() SpeculativeExecutionPolicy { + return b.spec +} + +func (b *Batch) SpeculativeExecutionPolicy(sp SpeculativeExecutionPolicy) *Batch { + b.spec = sp + return b +} + // Query adds the query to the batch operation func (b *Batch) Query(stmt string, args ...interface{}) { b.Entries = append(b.Entries, BatchEntry{Stmt: stmt, Args: args}) @@ -1378,11 +1633,25 @@ return b } -// WithContext will set the context to use during a query, it will be used to -// timeout when waiting for responses from Cassandra. +func (b *Batch) withContext(ctx context.Context) ExecutableQuery { + return b.WithContext(ctx) +} + +// WithContext returns a shallow copy of b with its context +// set to ctx. +// +// The provided context controls the entire lifetime of executing a +// query, queries will be canceled and return once the context is +// canceled. func (b *Batch) WithContext(ctx context.Context) *Batch { - b.context = ctx - return b + b2 := *b + b2.context = ctx + return &b2 +} + +// Deprecate: does nothing, cancel the context passed to WithContext +func (*Batch) Cancel() { + // TODO: delete } // Size returns the number of batch statements to be executed by the batch operation. @@ -1425,15 +1694,89 @@ return b } -func (b *Batch) attempt(d time.Duration) { - b.attempts++ - b.totalLatency += d.Nanoseconds() - // TODO: track latencies per host and things as well instead of just total +func (b *Batch) attempt(keyspace string, end, start time.Time, iter *Iter, host *HostInfo) { + b.AddAttempts(1, host) + b.AddLatency(end.Sub(start).Nanoseconds(), host) + + if b.observer == nil { + return + } + + statements := make([]string, len(b.Entries)) + for i, entry := range b.Entries { + statements[i] = entry.Stmt + } + + b.observer.ObserveBatch(b.Context(), ObservedBatch{ + Keyspace: keyspace, + Statements: statements, + Start: start, + End: end, + // Rows not used in batch observations // TODO - might be able to support it when using BatchCAS + Host: host, + Metrics: b.metrics.hostMetrics(host), + Err: iter.err, + }) } func (b *Batch) GetRoutingKey() ([]byte, error) { - // TODO: use the first statement in the batch as the routing key? - return nil, nil + if b.routingKey != nil { + return b.routingKey, nil + } + + if len(b.Entries) == 0 { + return nil, nil + } + + entry := b.Entries[0] + if entry.binding != nil { + // bindings do not have the values let's skip it like Query does. 
+ return nil, nil + } + // try to determine the routing key + routingKeyInfo, err := b.session.routingKeyInfo(b.Context(), entry.Stmt) + if err != nil { + return nil, err + } + + return createRoutingKey(routingKeyInfo, entry.Args) +} + +func createRoutingKey(routingKeyInfo *routingKeyInfo, values []interface{}) ([]byte, error) { + if routingKeyInfo == nil { + return nil, nil + } + + if len(routingKeyInfo.indexes) == 1 { + // single column routing key + routingKey, err := Marshal( + routingKeyInfo.types[0], + values[routingKeyInfo.indexes[0]], + ) + if err != nil { + return nil, err + } + return routingKey, nil + } + + // composite routing key + buf := bytes.NewBuffer(make([]byte, 0, 256)) + for i := range routingKeyInfo.indexes { + encoded, err := Marshal( + routingKeyInfo.types[i], + values[routingKeyInfo.indexes[i]], + ) + if err != nil { + return nil, err + } + lenBuf := []byte{0x00, 0x00} + binary.BigEndian.PutUint16(lenBuf, uint16(len(encoded))) + buf.Write(lenBuf) + buf.Write(encoded) + buf.WriteByte(0x00) + } + routingKey := buf.Bytes() + return routingKey, nil } type BatchType byte @@ -1445,9 +1788,10 @@ ) type BatchEntry struct { - Stmt string - Args []interface{} - binding func(q *QueryInfo) ([]interface{}, error) + Stmt string + Args []interface{} + Idempotent bool + binding func(q *QueryInfo) ([]interface{}, error) } type ColumnInfo struct { @@ -1544,12 +1888,12 @@ elapsed int ) - fmt.Fprintf(t.w, "Tracing session %016x (coordinator: %s, duration: %v):\n", - traceId, coordinator, time.Duration(duration)*time.Microsecond) - t.mu.Lock() defer t.mu.Unlock() + fmt.Fprintf(t.w, "Tracing session %016x (coordinator: %s, duration: %v):\n", + traceId, coordinator, time.Duration(duration)*time.Microsecond) + iter = t.session.control.query(`SELECT event_id, activity, source, source_elapsed FROM system_traces.events WHERE session_id = ?`, traceId) @@ -1564,6 +1908,88 @@ } } +type ObservedQuery struct { + Keyspace string + Statement string + + Start time.Time // time immediately before the query was called + End time.Time // time immediately after the query returned + + // Rows is the number of rows in the current iter. + // In paginated queries, rows from previous scans are not counted. + // Rows is not used in batch queries and remains at the default value + Rows int + + // Host is the informations about the host that performed the query + Host *HostInfo + + // The metrics per this host + Metrics *hostMetrics + + // Err is the error in the query. + // It only tracks network errors or errors of bad cassandra syntax, in particular selects with no match return nil error + Err error + + // Attempt is the index of attempt at executing this query. + // The first attempt is number zero and any retries have non-zero attempt number. + Attempt int +} + +// QueryObserver is the interface implemented by query observers / stat collectors. +// +// Experimental, this interface and use may change +type QueryObserver interface { + // ObserveQuery gets called on every query to cassandra, including all queries in an iterator when paging is enabled. + // It doesn't get called if there is no query because the session is closed or there are no connections available. + // The error reported only shows query errors, i.e. if a SELECT is valid but finds no matches it will be nil. 
+ ObserveQuery(context.Context, ObservedQuery) +} + +type ObservedBatch struct { + Keyspace string + Statements []string + + Start time.Time // time immediately before the batch query was called + End time.Time // time immediately after the batch query returned + + // Host is the informations about the host that performed the batch + Host *HostInfo + + // Err is the error in the batch query. + // It only tracks network errors or errors of bad cassandra syntax, in particular selects with no match return nil error + Err error + + // The metrics per this host + Metrics *hostMetrics +} + +// BatchObserver is the interface implemented by batch observers / stat collectors. +type BatchObserver interface { + // ObserveBatch gets called on every batch query to cassandra. + // It also gets called once for each query in a batch. + // It doesn't get called if there is no query because the session is closed or there are no connections available. + // The error reported only shows query errors, i.e. if a SELECT is valid but finds no matches it will be nil. + // Unlike QueryObserver.ObserveQuery it does no reporting on rows read. + ObserveBatch(context.Context, ObservedBatch) +} + +type ObservedConnect struct { + // Host is the information about the host about to connect + Host *HostInfo + + Start time.Time // time immediately before the dial is called + End time.Time // time immediately after the dial returned + + // Err is the connection error (if any) + Err error +} + +// ConnectObserver is the interface implemented by connect observers / stat collectors. +type ConnectObserver interface { + // ObserveConnect gets called when a new connection to cassandra is made. + ObserveConnect(ObservedConnect) +} + type Error struct { Code int Message string diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/session_test.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/session_test.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/session_test.go 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/session_test.go 2019-11-02 13:15:23.000000000 +0000 @@ -1,9 +1,11 @@ -// +build all integration +// +build all cassandra package gocql import ( + "context" "fmt" + "net" "testing" ) @@ -89,15 +91,24 @@ } } +type funcQueryObserver func(context.Context, ObservedQuery) + +func (f funcQueryObserver) ObserveQuery(ctx context.Context, o ObservedQuery) { + f(ctx, o) +} + func TestQueryBasicAPI(t *testing.T) { qry := &Query{} + // Initiate host + ip := "127.0.0.1" + + qry.metrics = preFilledQueryMetrics(map[string]*hostMetrics{ip: {Attempts: 0, TotalLatency: 0}}) if qry.Latency() != 0 { t.Fatalf("expected Query.Latency() to return 0, got %v", qry.Latency()) } - qry.attempts = 2 - qry.totalLatency = 4 + qry.metrics = preFilledQueryMetrics(map[string]*hostMetrics{ip: {Attempts: 2, TotalLatency: 4}}) if qry.Attempts() != 2 { t.Fatalf("expected Query.Attempts() to return 2, got %v", qry.Attempts()) } @@ -105,6 +116,11 @@ t.Fatalf("expected Query.Latency() to return 2, got %v", qry.Latency()) } + qry.AddAttempts(2, &HostInfo{hostname: ip, connectAddress: net.ParseIP(ip), port: 9042}) + if qry.Attempts() != 4 { + t.Fatalf("expected Query.Attempts() to return 4, got %v", qry.Attempts()) + } + qry.Consistency(All) if qry.GetConsistency() != All { t.Fatalf("expected Query.GetConsistency to return 'All', got '%s'", qry.GetConsistency()) @@ -116,6 +132,12 @@ t.Fatalf("expected Query.Trace to be '%v', got '%v'", trace, qry.trace) } + observer := 
funcQueryObserver(func(context.Context, ObservedQuery) {}) + qry.Observer(observer) + if qry.observer == nil { // can't compare func to func, checking not nil instead + t.Fatal("expected Query.QueryObserver to be set, got nil") + } + qry.PageSize(10) if qry.pageSize != 10 { t.Fatalf("expected Query.PageSize to be 10, got %v", qry.pageSize) @@ -169,6 +191,7 @@ s.pool = cfg.PoolConfig.buildPool(s) + // Test UnloggedBatch b := s.NewBatch(UnloggedBatch) if b.Type != UnloggedBatch { t.Fatalf("expceted batch.Type to be '%v', got '%v'", UnloggedBatch, b.Type) @@ -176,33 +199,45 @@ t.Fatalf("expceted batch.RetryPolicy to be '%v', got '%v'", cfg.RetryPolicy, b.rt) } - b = NewBatch(LoggedBatch) + // Test LoggedBatch + b = s.NewBatch(LoggedBatch) if b.Type != LoggedBatch { t.Fatalf("expected batch.Type to be '%v', got '%v'", LoggedBatch, b.Type) } - b.attempts = 1 + ip := "127.0.0.1" + + // Test attempts + b.metrics = preFilledQueryMetrics(map[string]*hostMetrics{ip: {Attempts: 1}}) if b.Attempts() != 1 { - t.Fatalf("expceted batch.Attempts() to return %v, got %v", 1, b.Attempts()) + t.Fatalf("expected batch.Attempts() to return %v, got %v", 1, b.Attempts()) + } + + b.AddAttempts(2, &HostInfo{hostname: ip, connectAddress: net.ParseIP(ip), port: 9042}) + if b.Attempts() != 3 { + t.Fatalf("expected batch.Attempts() to return %v, got %v", 3, b.Attempts()) } + // Test latency if b.Latency() != 0 { t.Fatalf("expected batch.Latency() to be 0, got %v", b.Latency()) } - b.totalLatency = 4 + b.metrics = preFilledQueryMetrics(map[string]*hostMetrics{ip: {Attempts: 1, TotalLatency: 4}}) if b.Latency() != 4 { t.Fatalf("expected batch.Latency() to return %v, got %v", 4, b.Latency()) } + // Test Consistency b.Cons = One if b.GetConsistency() != One { t.Fatalf("expected batch.GetConsistency() to return 'One', got '%s'", b.GetConsistency()) } + // Test batch.Query() b.Query("test", 1) if b.Entries[0].Stmt != "test" { - t.Fatalf("expected batch.Entries[0].Stmt to be 'test', got '%v'", b.Entries[0].Stmt) + t.Fatalf("expected batch.Entries[0].Statement to be 'test', got '%v'", b.Entries[0].Stmt) } else if b.Entries[0].Args[0].(int) != 1 { t.Fatalf("expected batch.Entries[0].Args[0] to be 1, got %v", b.Entries[0].Args[0]) } @@ -212,10 +247,12 @@ }) if b.Entries[1].Stmt != "test2" { - t.Fatalf("expected batch.Entries[1].Stmt to be 'test2', got '%v'", b.Entries[1].Stmt) + t.Fatalf("expected batch.Entries[1].Statement to be 'test2', got '%v'", b.Entries[1].Stmt) } else if b.Entries[1].binding == nil { t.Fatal("expected batch.Entries[1].binding to be defined, got nil") } + + // Test RetryPolicy r := &SimpleRetryPolicy{NumRetries: 4} b.RetryPolicy(r) @@ -250,3 +287,30 @@ } } } + +func TestIsUseStatement(t *testing.T) { + testCases := []struct { + input string + exp bool + }{ + {"USE foo", true}, + {"USe foo", true}, + {"UsE foo", true}, + {"Use foo", true}, + {"uSE foo", true}, + {"uSe foo", true}, + {"usE foo", true}, + {"use foo", true}, + {"SELECT ", false}, + {"UPDATE ", false}, + {"INSERT ", false}, + {"", false}, + } + + for _, tc := range testCases { + v := isUseStatement(tc.input) + if v != tc.exp { + t.Fatalf("expected %v but got %v for statement %q", tc.exp, v, tc.input) + } + } +} diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/stress_test.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/stress_test.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/stress_test.go 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/stress_test.go 2019-11-02 
13:15:23.000000000 +0000 @@ -1,4 +1,4 @@ -// +build all integration +// +build all cassandra package gocql diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/token.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/token.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/token.go 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/token.go 2019-11-02 13:15:23.000000000 +0000 @@ -39,7 +39,7 @@ func (p murmur3Partitioner) Hash(partitionKey []byte) token { h1 := murmur.Murmur3H1(partitionKey) - return murmur3Token(int64(h1)) + return murmur3Token(h1) } // murmur3 little-endian, 128-bit hash, but returns only h1 @@ -58,7 +58,7 @@ // order preserving partitioner and token type orderedPartitioner struct{} -type orderedToken []byte +type orderedToken string func (p orderedPartitioner) Name() string { return "OrderedPartitioner" @@ -70,15 +70,15 @@ } func (p orderedPartitioner) ParseString(str string) token { - return orderedToken([]byte(str)) + return orderedToken(str) } func (o orderedToken) String() string { - return string([]byte(o)) + return string(o) } func (o orderedToken) Less(token token) bool { - return -1 == bytes.Compare(o, token.(orderedToken)) + return o < token.(orderedToken) } // random partitioner and token @@ -118,17 +118,25 @@ return -1 == (*big.Int)(r).Cmp((*big.Int)(token.(*randomToken))) } +type hostToken struct { + token token + host *HostInfo +} + +func (ht hostToken) String() string { + return fmt.Sprintf("{token=%v host=%v}", ht.token, ht.host.HostID()) +} + // a data structure for organizing the relationship between tokens and hosts type tokenRing struct { partitioner partitioner - tokens []token + tokens []hostToken hosts []*HostInfo } func newTokenRing(partitioner string, hosts []*HostInfo) (*tokenRing, error) { tokenRing := &tokenRing{ - tokens: []token{}, - hosts: []*HostInfo{}, + hosts: hosts, } if strings.HasSuffix(partitioner, "Murmur3Partitioner") { @@ -144,8 +152,7 @@ for _, host := range hosts { for _, strToken := range host.Tokens() { token := tokenRing.partitioner.ParseString(strToken) - tokenRing.tokens = append(tokenRing.tokens, token) - tokenRing.hosts = append(tokenRing.hosts, host) + tokenRing.tokens = append(tokenRing.tokens, hostToken{token, host}) } } @@ -159,16 +166,14 @@ } func (t *tokenRing) Less(i, j int) bool { - return t.tokens[i].Less(t.tokens[j]) + return t.tokens[i].token.Less(t.tokens[j].token) } func (t *tokenRing) Swap(i, j int) { - t.tokens[i], t.hosts[i], t.tokens[j], t.hosts[j] = - t.tokens[j], t.hosts[j], t.tokens[i], t.hosts[i] + t.tokens[i], t.tokens[j] = t.tokens[j], t.tokens[i] } func (t *tokenRing) String() string { - buf := &bytes.Buffer{} buf.WriteString("TokenRing(") if t.partitioner != nil { @@ -176,52 +181,43 @@ } buf.WriteString("){") sep := "" - for i := range t.tokens { + for i, th := range t.tokens { buf.WriteString(sep) sep = "," buf.WriteString("\n\t[") buf.WriteString(strconv.Itoa(i)) buf.WriteString("]") - buf.WriteString(t.tokens[i].String()) + buf.WriteString(th.token.String()) buf.WriteString(":") - buf.WriteString(t.hosts[i].ConnectAddress().String()) + buf.WriteString(th.host.ConnectAddress().String()) } buf.WriteString("\n}") return string(buf.Bytes()) } -func (t *tokenRing) GetHostForPartitionKey(partitionKey []byte) *HostInfo { +func (t *tokenRing) GetHostForPartitionKey(partitionKey []byte) (host *HostInfo, endToken token) { if t == nil { - return nil + return nil, nil } - token := t.partitioner.Hash(partitionKey) - return t.GetHostForToken(token) 
+ return t.GetHostForToken(t.partitioner.Hash(partitionKey)) } -func (t *tokenRing) GetHostForToken(token token) *HostInfo { - if t == nil { - return nil - } - - l := len(t.tokens) - // no host tokens, no available hosts - if l == 0 { - return nil +func (t *tokenRing) GetHostForToken(token token) (host *HostInfo, endToken token) { + if t == nil || len(t.tokens) == 0 { + return nil, nil } // find the primary replica - ringIndex := sort.Search( - l, - func(i int) bool { - return !t.tokens[i].Less(token) - }, - ) + p := sort.Search(len(t.tokens), func(i int) bool { + return !t.tokens[i].token.Less(token) + }) - if ringIndex == l { + if p == len(t.tokens) { // wrap around to the first in the ring - ringIndex = 0 + p = 0 } - host := t.hosts[ringIndex] - return host + + v := t.tokens[p] + return v.host, v.token } diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/token_test.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/token_test.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/token_test.go 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/token_test.go 2019-11-02 13:15:23.000000000 +0000 @@ -132,18 +132,13 @@ type intToken int -func (i intToken) String() string { - return strconv.Itoa(int(i)) -} - -func (i intToken) Less(token token) bool { - return i < token.(intToken) -} +func (i intToken) String() string { return strconv.Itoa(int(i)) } +func (i intToken) Less(token token) bool { return i < token.(intToken) } // Test of the token ring implementation based on example at the start of this // page of documentation: // http://www.datastax.com/docs/0.8/cluster_architecture/partitioning -func TestIntTokenRing(t *testing.T) { +func TestTokenRing_Int(t *testing.T) { host0 := &HostInfo{} host25 := &HostInfo{} host50 := &HostInfo{} @@ -151,77 +146,71 @@ ring := &tokenRing{ partitioner: nil, // these tokens and hosts are out of order to test sorting - tokens: []token{ - intToken(0), - intToken(50), - intToken(75), - intToken(25), - }, - hosts: []*HostInfo{ - host0, - host50, - host75, - host25, + tokens: []hostToken{ + {intToken(0), host0}, + {intToken(50), host50}, + {intToken(75), host75}, + {intToken(25), host25}, }, } sort.Sort(ring) - if ring.GetHostForToken(intToken(0)) != host0 { + if host, endToken := ring.GetHostForToken(intToken(0)); host != host0 || endToken != intToken(0) { t.Error("Expected host 0 for token 0") } - if ring.GetHostForToken(intToken(1)) != host25 { + if host, endToken := ring.GetHostForToken(intToken(1)); host != host25 || endToken != intToken(25) { t.Error("Expected host 25 for token 1") } - if ring.GetHostForToken(intToken(24)) != host25 { + if host, endToken := ring.GetHostForToken(intToken(24)); host != host25 || endToken != intToken(25) { t.Error("Expected host 25 for token 24") } - if ring.GetHostForToken(intToken(25)) != host25 { + if host, endToken := ring.GetHostForToken(intToken(25)); host != host25 || endToken != intToken(25) { t.Error("Expected host 25 for token 25") } - if ring.GetHostForToken(intToken(26)) != host50 { + if host, endToken := ring.GetHostForToken(intToken(26)); host != host50 || endToken != intToken(50) { t.Error("Expected host 50 for token 26") } - if ring.GetHostForToken(intToken(49)) != host50 { + if host, endToken := ring.GetHostForToken(intToken(49)); host != host50 || endToken != intToken(50) { t.Error("Expected host 50 for token 49") } - if ring.GetHostForToken(intToken(50)) != host50 { + if host, endToken := ring.GetHostForToken(intToken(50)); host != host50 || 
endToken != intToken(50) { t.Error("Expected host 50 for token 50") } - if ring.GetHostForToken(intToken(51)) != host75 { + if host, endToken := ring.GetHostForToken(intToken(51)); host != host75 || endToken != intToken(75) { t.Error("Expected host 75 for token 51") } - if ring.GetHostForToken(intToken(74)) != host75 { + if host, endToken := ring.GetHostForToken(intToken(74)); host != host75 || endToken != intToken(75) { t.Error("Expected host 75 for token 74") } - if ring.GetHostForToken(intToken(75)) != host75 { + if host, endToken := ring.GetHostForToken(intToken(75)); host != host75 || endToken != intToken(75) { t.Error("Expected host 75 for token 75") } - if ring.GetHostForToken(intToken(76)) != host0 { + if host, endToken := ring.GetHostForToken(intToken(76)); host != host0 || endToken != intToken(0) { t.Error("Expected host 0 for token 76") } - if ring.GetHostForToken(intToken(99)) != host0 { + if host, endToken := ring.GetHostForToken(intToken(99)); host != host0 || endToken != intToken(0) { t.Error("Expected host 0 for token 99") } - if ring.GetHostForToken(intToken(100)) != host0 { + if host, endToken := ring.GetHostForToken(intToken(100)); host != host0 || endToken != intToken(0) { t.Error("Expected host 0 for token 100") } } // Test for the behavior of a nil pointer to tokenRing -func TestNilTokenRing(t *testing.T) { +func TestTokenRing_Nil(t *testing.T) { var ring *tokenRing = nil - if ring.GetHostForToken(nil) != nil { + if host, endToken := ring.GetHostForToken(nil); host != nil || endToken != nil { t.Error("Expected nil for nil token ring") } - if ring.GetHostForPartitionKey(nil) != nil { + if host, endToken := ring.GetHostForPartitionKey(nil); host != nil || endToken != nil { t.Error("Expected nil for nil token ring") } } // Test of the recognition of the partitioner class -func TestUnknownTokenRing(t *testing.T) { +func TestTokenRing_UnknownPartition(t *testing.T) { _, err := newTokenRing("UnknownPartitioner", nil) if err == nil { t.Error("Expected error for unknown partitioner value, but was nil") @@ -242,7 +231,7 @@ } // Test of the tokenRing with the Murmur3Partitioner -func TestMurmur3TokenRing(t *testing.T) { +func TestTokenRing_Murmur3(t *testing.T) { // Note, strings are parsed directly to int64, they are not murmur3 hashed hosts := hostsForTests(4) ring, err := newTokenRing("Murmur3Partitioner", hosts) @@ -253,26 +242,26 @@ p := murmur3Partitioner{} for _, host := range hosts { - actual := ring.GetHostForToken(p.ParseString(host.tokens[0])) + actual, _ := ring.GetHostForToken(p.ParseString(host.tokens[0])) if !actual.ConnectAddress().Equal(host.ConnectAddress()) { t.Errorf("Expected address %v for token %q, but was %v", host.ConnectAddress(), host.tokens[0], actual.ConnectAddress()) } } - actual := ring.GetHostForToken(p.ParseString("12")) + actual, _ := ring.GetHostForToken(p.ParseString("12")) if !actual.ConnectAddress().Equal(hosts[1].ConnectAddress()) { t.Errorf("Expected address 1 for token \"12\", but was %s", actual.ConnectAddress()) } - actual = ring.GetHostForToken(p.ParseString("24324545443332")) + actual, _ = ring.GetHostForToken(p.ParseString("24324545443332")) if !actual.ConnectAddress().Equal(hosts[0].ConnectAddress()) { t.Errorf("Expected address 0 for token \"24324545443332\", but was %s", actual.ConnectAddress()) } } // Test of the tokenRing with the OrderedPartitioner -func TestOrderedTokenRing(t *testing.T) { +func TestTokenRing_Ordered(t *testing.T) { // Tokens here more or less are similar layout to the int tokens above due // to each numeric 
character translating to a consistently offset byte. hosts := hostsForTests(4) @@ -285,26 +274,26 @@ var actual *HostInfo for _, host := range hosts { - actual = ring.GetHostForToken(p.ParseString(host.tokens[0])) + actual, _ := ring.GetHostForToken(p.ParseString(host.tokens[0])) if !actual.ConnectAddress().Equal(host.ConnectAddress()) { t.Errorf("Expected address %v for token %q, but was %v", host.ConnectAddress(), host.tokens[0], actual.ConnectAddress()) } } - actual = ring.GetHostForToken(p.ParseString("12")) + actual, _ = ring.GetHostForToken(p.ParseString("12")) if !actual.peer.Equal(hosts[1].peer) { t.Errorf("Expected address 1 for token \"12\", but was %s", actual.ConnectAddress()) } - actual = ring.GetHostForToken(p.ParseString("24324545443332")) + actual, _ = ring.GetHostForToken(p.ParseString("24324545443332")) if !actual.ConnectAddress().Equal(hosts[1].ConnectAddress()) { t.Errorf("Expected address 1 for token \"24324545443332\", but was %s", actual.ConnectAddress()) } } // Test of the tokenRing with the RandomPartitioner -func TestRandomTokenRing(t *testing.T) { +func TestTokenRing_Random(t *testing.T) { // String tokens are parsed into big.Int in base 10 hosts := hostsForTests(4) ring, err := newTokenRing("RandomPartitioner", hosts) @@ -316,19 +305,19 @@ var actual *HostInfo for _, host := range hosts { - actual = ring.GetHostForToken(p.ParseString(host.tokens[0])) + actual, _ := ring.GetHostForToken(p.ParseString(host.tokens[0])) if !actual.ConnectAddress().Equal(host.ConnectAddress()) { t.Errorf("Expected address %v for token %q, but was %v", host.ConnectAddress(), host.tokens[0], actual.ConnectAddress()) } } - actual = ring.GetHostForToken(p.ParseString("12")) + actual, _ = ring.GetHostForToken(p.ParseString("12")) if !actual.peer.Equal(hosts[1].peer) { t.Errorf("Expected address 1 for token \"12\", but was %s", actual.ConnectAddress()) } - actual = ring.GetHostForToken(p.ParseString("24324545443332")) + actual, _ = ring.GetHostForToken(p.ParseString("24324545443332")) if !actual.ConnectAddress().Equal(hosts[0].ConnectAddress()) { t.Errorf("Expected address 1 for token \"24324545443332\", but was %s", actual.ConnectAddress()) } diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/topology.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/topology.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/topology.go 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/topology.go 2019-11-02 13:15:23.000000000 +0000 @@ -1,74 +1,276 @@ -// Copyright (c) 2012 The gocql Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
- package gocql import ( - "sync" - "sync/atomic" + "fmt" + "sort" + "strconv" + "strings" ) -type Node interface { - Pick(qry *Query) *Conn - Close() +type hostTokens struct { + token token + hosts []*HostInfo } -type RoundRobin struct { - pool []Node - pos uint32 - mu sync.RWMutex +type tokenRingReplicas []hostTokens + +func (h tokenRingReplicas) Less(i, j int) bool { return h[i].token.Less(h[j].token) } +func (h tokenRingReplicas) Len() int { return len(h) } +func (h tokenRingReplicas) Swap(i, j int) { h[i], h[j] = h[j], h[i] } + +func (h tokenRingReplicas) replicasFor(t token) *hostTokens { + if len(h) == 0 { + return nil + } + + p := sort.Search(len(h), func(i int) bool { + return !h[i].token.Less(t) + }) + + // TODO: simplify this + if p < len(h) && h[p].token == t { + return &h[p] + } + + p-- + + if p >= len(h) { + // rollover + p = 0 + } else if p < 0 { + // rollunder + p = len(h) - 1 + } + + return &h[p] } -func NewRoundRobin() *RoundRobin { - return &RoundRobin{} +type placementStrategy interface { + replicaMap(tokenRing *tokenRing) tokenRingReplicas + replicationFactor(dc string) int } -func (r *RoundRobin) AddNode(node Node) { - r.mu.Lock() - r.pool = append(r.pool, node) - r.mu.Unlock() +func getReplicationFactorFromOpts(keyspace string, val interface{}) int { + // TODO: dont really want to panic here, but is better + // than spamming + switch v := val.(type) { + case int: + if v <= 0 { + panic(fmt.Sprintf("invalid replication_factor %d. Is the %q keyspace configured correctly?", v, keyspace)) + } + return v + case string: + n, err := strconv.Atoi(v) + if err != nil { + panic(fmt.Sprintf("invalid replication_factor. Is the %q keyspace configured correctly? %v", keyspace, err)) + } else if n <= 0 { + panic(fmt.Sprintf("invalid replication_factor %d. 
Is the %q keyspace configured correctly?", n, keyspace)) + } + return n + default: + panic(fmt.Sprintf("unkown replication_factor type %T", v)) + } } -func (r *RoundRobin) RemoveNode(node Node) { - r.mu.Lock() - n := len(r.pool) - for i := 0; i < n; i++ { - if r.pool[i] == node { - r.pool[i], r.pool[n-1] = r.pool[n-1], r.pool[i] - r.pool = r.pool[:n-1] - break +func getStrategy(ks *KeyspaceMetadata) placementStrategy { + switch { + case strings.Contains(ks.StrategyClass, "SimpleStrategy"): + return &simpleStrategy{rf: getReplicationFactorFromOpts(ks.Name, ks.StrategyOptions["replication_factor"])} + case strings.Contains(ks.StrategyClass, "NetworkTopologyStrategy"): + dcs := make(map[string]int) + for dc, rf := range ks.StrategyOptions { + if dc == "class" { + continue + } + + dcs[dc] = getReplicationFactorFromOpts(ks.Name+":dc="+dc, rf) } + return &networkTopology{dcs: dcs} + case strings.Contains(ks.StrategyClass, "LocalStrategy"): + return nil + default: + // TODO: handle unknown replicas and just return the primary host for a token + panic(fmt.Sprintf("unsupported strategy class: %v", ks.StrategyClass)) } - r.mu.Unlock() } -func (r *RoundRobin) Size() int { - r.mu.RLock() - n := len(r.pool) - r.mu.RUnlock() - return n +type simpleStrategy struct { + rf int +} + +func (s *simpleStrategy) replicationFactor(dc string) int { + return s.rf } -func (r *RoundRobin) Pick(qry *Query) *Conn { - pos := atomic.AddUint32(&r.pos, 1) - var node Node - r.mu.RLock() - if len(r.pool) > 0 { - node = r.pool[pos%uint32(len(r.pool))] +func (s *simpleStrategy) replicaMap(tokenRing *tokenRing) tokenRingReplicas { + tokens := tokenRing.tokens + ring := make(tokenRingReplicas, len(tokens)) + + for i, th := range tokens { + replicas := make([]*HostInfo, 0, s.rf) + seen := make(map[*HostInfo]bool) + + for j := 0; j < len(tokens) && len(replicas) < s.rf; j++ { + h := tokens[(i+j)%len(tokens)] + if !seen[h.host] { + replicas = append(replicas, h.host) + seen[h.host] = true + } + } + + ring[i] = hostTokens{th.token, replicas} } - r.mu.RUnlock() - if node == nil { - return nil + + sort.Sort(ring) + + return ring +} + +type networkTopology struct { + dcs map[string]int +} + +func (n *networkTopology) replicationFactor(dc string) int { + return n.dcs[dc] +} + +func (n *networkTopology) haveRF(replicaCounts map[string]int) bool { + if len(replicaCounts) != len(n.dcs) { + return false } - return node.Pick(qry) + + for dc, rf := range n.dcs { + if rf != replicaCounts[dc] { + return false + } + } + + return true } -func (r *RoundRobin) Close() { - r.mu.Lock() - for i := 0; i < len(r.pool); i++ { - r.pool[i].Close() +func (n *networkTopology) replicaMap(tokenRing *tokenRing) tokenRingReplicas { + dcRacks := make(map[string]map[string]struct{}, len(n.dcs)) + // skipped hosts in a dc + skipped := make(map[string][]*HostInfo, len(n.dcs)) + // number of replicas per dc + replicasInDC := make(map[string]int, len(n.dcs)) + // dc -> racks + seenDCRacks := make(map[string]map[string]struct{}, len(n.dcs)) + + for _, h := range tokenRing.hosts { + dc := h.DataCenter() + rack := h.Rack() + + racks, ok := dcRacks[dc] + if !ok { + racks = make(map[string]struct{}) + dcRacks[dc] = racks + } + racks[rack] = struct{}{} + } + + for dc, racks := range dcRacks { + replicasInDC[dc] = 0 + seenDCRacks[dc] = make(map[string]struct{}, len(racks)) + } + + tokens := tokenRing.tokens + replicaRing := make(tokenRingReplicas, len(tokens)) + + var totalRF int + for _, rf := range n.dcs { + totalRF += rf } - r.pool = nil - r.mu.Unlock() + + for i, th := 
range tokenRing.tokens { + for k, v := range skipped { + skipped[k] = v[:0] + } + + for dc := range n.dcs { + replicasInDC[dc] = 0 + for rack := range seenDCRacks[dc] { + delete(seenDCRacks[dc], rack) + } + } + + replicas := make([]*HostInfo, 0, totalRF) + for j := 0; j < len(tokens) && (len(replicas) < totalRF && !n.haveRF(replicasInDC)); j++ { + // TODO: ensure we dont add the same host twice + p := i + j + if p >= len(tokens) { + p -= len(tokens) + } + h := tokens[p].host + + dc := h.DataCenter() + rack := h.Rack() + + rf, ok := n.dcs[dc] + if !ok { + // skip this DC, dont know about it + continue + } else if replicasInDC[dc] >= rf { + if replicasInDC[dc] > rf { + panic(fmt.Sprintf("replica overflow. rf=%d have=%d in dc %q", rf, replicasInDC[dc], dc)) + } + + // have enough replicas in this DC + continue + } else if _, ok := dcRacks[dc][rack]; !ok { + // dont know about this rack + continue + } + + racks := seenDCRacks[dc] + if _, ok := racks[rack]; ok && len(racks) == len(dcRacks[dc]) { + // we have been through all the racks and dont have RF yet, add this + replicas = append(replicas, h) + replicasInDC[dc]++ + } else if !ok { + if racks == nil { + racks = make(map[string]struct{}, 1) + seenDCRacks[dc] = racks + } + + // new rack + racks[rack] = struct{}{} + replicas = append(replicas, h) + r := replicasInDC[dc] + 1 + + if len(racks) == len(dcRacks[dc]) { + // if we have been through all the racks, drain the rest of the skipped + // hosts until we have RF. The next iteration will skip in the block + // above + skippedHosts := skipped[dc] + var k int + for ; k < len(skippedHosts) && r+k < rf; k++ { + sh := skippedHosts[k] + replicas = append(replicas, sh) + } + r += k + skipped[dc] = skippedHosts[k:] + } + replicasInDC[dc] = r + } else { + // already seen this rack, keep hold of this host incase + // we dont get enough for rf + skipped[dc] = append(skipped[dc], h) + } + } + + if len(replicas) == 0 { + panic(fmt.Sprintf("no replicas for token: %v", th.token)) + } else if !replicas[0].Equal(th.host) { + panic(fmt.Sprintf("first replica is not the primary replica for the token: expected %v got %v", replicas[0].ConnectAddress(), th.host.ConnectAddress())) + } + + replicaRing[i] = hostTokens{th.token, replicas} + } + + if len(replicaRing) != len(tokens) { + panic(fmt.Sprintf("token map different size to token ring: got %d expected %d", len(replicaRing), len(tokens))) + } + + return replicaRing } diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/topology_test.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/topology_test.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/topology_test.go 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/topology_test.go 2019-11-02 13:15:23.000000000 +0000 @@ -1,51 +1,166 @@ -// +build all unit - package gocql import ( + "fmt" + "sort" "testing" ) -// fakeNode is used as a simple structure to test the RoundRobin API -type fakeNode struct { - conn *Conn - closed bool -} +func TestPlacementStrategy_SimpleStrategy(t *testing.T) { + host0 := &HostInfo{hostId: "0"} + host25 := &HostInfo{hostId: "25"} + host50 := &HostInfo{hostId: "50"} + host75 := &HostInfo{hostId: "75"} + + tokens := []hostToken{ + {intToken(0), host0}, + {intToken(25), host25}, + {intToken(50), host50}, + {intToken(75), host75}, + } -// Pick is needed to satisfy the Node interface -func (n *fakeNode) Pick(qry *Query) *Conn { - if n.conn == nil { - n.conn = &Conn{} + hosts := []*HostInfo{host0, host25, host50, host75} + + strat := 
&simpleStrategy{rf: 2} + tokenReplicas := strat.replicaMap(&tokenRing{hosts: hosts, tokens: tokens}) + if len(tokenReplicas) != len(tokens) { + t.Fatalf("expected replica map to have %d items but has %d", len(tokens), len(tokenReplicas)) } - return n.conn -} -//Close is needed to satisfy the Node interface -func (n *fakeNode) Close() { - n.closed = true + for _, replicas := range tokenReplicas { + if len(replicas.hosts) != strat.rf { + t.Errorf("expected to have %d replicas got %d for token=%v", strat.rf, len(replicas.hosts), replicas.token) + } + } + + for i, token := range tokens { + ht := tokenReplicas.replicasFor(token.token) + if ht.token != token.token { + t.Errorf("token %v not in replica map: %v", token, ht.hosts) + } + + for j, replica := range ht.hosts { + exp := tokens[(i+j)%len(tokens)].host + if exp != replica { + t.Errorf("expected host %v to be a replica of %v got %v", exp.hostId, token, replica.hostId) + } + } + } } -//TestRoundRobinAPI tests the exported methods of the RoundRobin struct -//to make sure the API behaves accordingly. -func TestRoundRobinAPI(t *testing.T) { - node := &fakeNode{} - rr := NewRoundRobin() - rr.AddNode(node) +func TestPlacementStrategy_NetworkStrategy(t *testing.T) { + var ( + hosts []*HostInfo + tokens []hostToken + ) + + const ( + totalDCs = 3 + racksPerDC = 3 + hostsPerDC = 5 + ) + + dcRing := make(map[string][]hostToken, totalDCs) + for i := 0; i < totalDCs; i++ { + var dcTokens []hostToken + dc := fmt.Sprintf("dc%d", i+1) + + for j := 0; j < hostsPerDC; j++ { + rack := fmt.Sprintf("rack%d", (j%racksPerDC)+1) + + h := &HostInfo{hostId: fmt.Sprintf("%s:%s:%d", dc, rack, j), dataCenter: dc, rack: rack} + + token := hostToken{ + token: orderedToken([]byte(h.hostId)), + host: h, + } + + tokens = append(tokens, token) + dcTokens = append(dcTokens, token) - if rr.Size() != 1 { - t.Fatalf("expected size to be 1, got %v", rr.Size()) + hosts = append(hosts, h) + } + + sort.Sort(&tokenRing{tokens: dcTokens}) + dcRing[dc] = dcTokens + } + + if len(tokens) != hostsPerDC*totalDCs { + t.Fatalf("expected %d tokens in the ring got %d", hostsPerDC*totalDCs, len(tokens)) + } + sort.Sort(&tokenRing{tokens: tokens}) + + strat := &networkTopology{ + dcs: map[string]int{ + "dc1": 1, + "dc2": 2, + "dc3": 3, + }, + } + + var expReplicas int + for _, rf := range strat.dcs { + expReplicas += rf } - if c := rr.Pick(nil); c != node.conn { - t.Fatalf("expected conn %v, got %v", node.conn, c) + tokenReplicas := strat.replicaMap(&tokenRing{hosts: hosts, tokens: tokens}) + if len(tokenReplicas) != len(tokens) { + t.Fatalf("expected replica map to have %d items but has %d", len(tokens), len(tokenReplicas)) + } + if !sort.IsSorted(tokenReplicas) { + t.Fatal("replica map was not sorted by token") } - rr.Close() - if rr.pool != nil { - t.Fatalf("expected rr.pool to be nil, got %v", rr.pool) + for token, replicas := range tokenReplicas { + if len(replicas.hosts) != expReplicas { + t.Fatalf("expected to have %d replicas got %d for token=%v", expReplicas, len(replicas.hosts), token) + } } - if !node.closed { - t.Fatal("expected node.closed to be true, got false") + for dc, rf := range strat.dcs { + dcTokens := dcRing[dc] + for i, th := range dcTokens { + token := th.token + allReplicas := tokenReplicas.replicasFor(token) + if allReplicas.token != token { + t.Fatalf("token %v not in replica map", token) + } + + var replicas []*HostInfo + for _, replica := range allReplicas.hosts { + if replica.dataCenter == dc { + replicas = append(replicas, replica) + } + } + + if len(replicas) != 
rf { + t.Fatalf("expected %d replicas in dc %q got %d", rf, dc, len(replicas)) + } + + var lastRack string + for j, replica := range replicas { + // expected is in the next rack + var exp *HostInfo + if lastRack == "" { + // primary, first replica + exp = dcTokens[(i+j)%len(dcTokens)].host + } else { + for k := 0; k < len(dcTokens); k++ { + // walk around the ring from i + j to find the next host the + // next rack + p := (i + j + k) % len(dcTokens) + h := dcTokens[p].host + if h.rack != lastRack { + exp = h + break + } + } + if exp.rack == lastRack { + panic("no more racks") + } + } + lastRack = replica.rack + } + } } } diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/.travis.yml golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/.travis.yml --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/.travis.yml 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/.travis.yml 2019-11-02 13:15:23.000000000 +0000 @@ -11,29 +11,33 @@ matrix: fast_finish: true +branches: + only: + - master + env: global: - GOMAXPROCS=2 matrix: - - CASS=2.1.12 - AUTH=false - - CASS=2.2.5 + - CASS=2.1.21 AUTH=true - - CASS=2.2.5 + - CASS=2.2.14 + AUTH=true + - CASS=2.2.14 + AUTH=false + - CASS=3.0.18 AUTH=false - - CASS=3.0.8 + - CASS=3.11.4 AUTH=false go: - - 1.8 - - 1.9 + - 1.12.x + - 1.13.x install: - - pip install --user cql PyYAML six - - git clone https://github.com/pcmanus/ccm.git - - pushd ccm - - ./setup.py install --user - - popd + - ./install_test_deps.sh $TRAVIS_REPO_SLUG + - cd ../.. + - cd gocql/gocql - go get . script: diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/udt_test.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/udt_test.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/udt_test.go 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/udt_test.go 2019-11-02 13:15:23.000000000 +0000 @@ -1,4 +1,4 @@ -// +build all integration +// +build all cassandra package gocql diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/uuid.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/uuid.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/uuid.go 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/uuid.go 2019-11-02 13:15:23.000000000 +0000 @@ -112,29 +112,82 @@ var timeBase = time.Date(1582, time.October, 15, 0, 0, 0, 0, time.UTC).Unix() +// getTimestamp converts time to UUID (version 1) timestamp. +// It must be an interval of 100-nanoseconds since timeBase. +func getTimestamp(t time.Time) int64 { + utcTime := t.In(time.UTC) + ts := int64(utcTime.Unix()-timeBase)*10000000 + int64(utcTime.Nanosecond()/100) + + return ts +} + // TimeUUID generates a new time based UUID (version 1) using the current // time as the timestamp. func TimeUUID() UUID { return UUIDFromTime(time.Now()) } +// The min and max clock values for a UUID. +// +// Cassandra's TimeUUIDType compares the lsb parts as signed byte arrays. +// Thus, the min value for each byte is -128 and the max is +127. +const ( + minClock = 0x8080 + maxClock = 0x7f7f +) + +// The min and max node values for a UUID. +// +// See explanation about Cassandra's TimeUUIDType comparison logic above. +var ( + minNode = []byte{0x80, 0x80, 0x80, 0x80, 0x80, 0x80} + maxNode = []byte{0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f} +) + +// MinTimeUUID generates a "fake" time based UUID (version 1) which will be +// the smallest possible UUID generated for the provided timestamp. 
+// +// UUIDs generated by this function are not unique and are mostly suitable only +// in queries to select a time range of a Cassandra's TimeUUID column. +func MinTimeUUID(t time.Time) UUID { + return TimeUUIDWith(getTimestamp(t), minClock, minNode) +} + +// MaxTimeUUID generates a "fake" time based UUID (version 1) which will be +// the biggest possible UUID generated for the provided timestamp. +// +// UUIDs generated by this function are not unique and are mostly suitable only +// in queries to select a time range of a Cassandra's TimeUUID column. +func MaxTimeUUID(t time.Time) UUID { + return TimeUUIDWith(getTimestamp(t), maxClock, maxNode) +} + // UUIDFromTime generates a new time based UUID (version 1) as described in // RFC 4122. This UUID contains the MAC address of the node that generated // the UUID, the given timestamp and a sequence number. -func UUIDFromTime(aTime time.Time) UUID { +func UUIDFromTime(t time.Time) UUID { + ts := getTimestamp(t) + clock := atomic.AddUint32(&clockSeq, 1) + + return TimeUUIDWith(ts, clock, hardwareAddr) +} + +// TimeUUIDWith generates a new time based UUID (version 1) as described in +// RFC4122 with given parameters. t is the number of 100's of nanoseconds +// since 15 Oct 1582 (60bits). clock is the number of clock sequence (14bits). +// node is a slice to gurarantee the uniqueness of the UUID (up to 6bytes). +// Note: calling this function does not increment the static clock sequence. +func TimeUUIDWith(t int64, clock uint32, node []byte) UUID { var u UUID - utcTime := aTime.In(time.UTC) - t := uint64(utcTime.Unix()-timeBase)*10000000 + uint64(utcTime.Nanosecond()/100) u[0], u[1], u[2], u[3] = byte(t>>24), byte(t>>16), byte(t>>8), byte(t) u[4], u[5] = byte(t>>40), byte(t>>32) u[6], u[7] = byte(t>>56)&0x0F, byte(t>>48) - clock := atomic.AddUint32(&clockSeq, 1) u[8] = byte(clock >> 8) u[9] = byte(clock) - copy(u[10:], hardwareAddr) + copy(u[10:], node) u[6] |= 0x10 // set version to 1 (time based uuid) u[8] &= 0x3F // clear variant @@ -198,6 +251,17 @@ return u[10:] } +// Clock extracts the clock sequence of this UUID. It will return zero if the +// UUID is not a time based UUID (version 1). +func (u UUID) Clock() uint32 { + if u.Version() != 1 { + return 0 + } + + // Clock sequence is the lower 14bits of u[8:10] + return uint32(u[8]&0x3F)<<8 | uint32(u[9]) +} + // Timestamp extracts the timestamp information from a time based UUID // (version 1). func (u UUID) Timestamp() int64 { diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/uuid_test.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/uuid_test.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/uuid_test.go 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/uuid_test.go 2019-11-02 13:15:23.000000000 +0000 @@ -142,6 +142,30 @@ } } +func TestTimeUUIDWith(t *testing.T) { + utcTime := time.Date(1982, 5, 5, 12, 34, 56, 400, time.UTC) + ts := int64(utcTime.Unix()-timeBase)*10000000 + int64(utcTime.Nanosecond()/100) + clockSeq := uint32(0x3FFF) // Max number of clock sequence. + node := [7]byte{0, 1, 2, 3, 4, 5, 6} // The last element should be ignored. + uuid := TimeUUIDWith(ts, clockSeq, node[:]) + + if got := uuid.Variant(); got != VariantIETF { + t.Errorf("wrong variant. expected %d got %d", VariantIETF, got) + } + if got, want := uuid.Version(), 1; got != want { + t.Errorf("wrong version. Expected %v got %v", want, got) + } + if got := uuid.Timestamp(); got != int64(ts) { + t.Errorf("wrong timestamp. 
Expected %v got %v", ts, got) + } + if got := uuid.Clock(); uint32(got) != clockSeq { + t.Errorf("wrong clock. expected %v got %v", clockSeq, got) + } + if got, want := uuid.Node(), node[:6]; !bytes.Equal(got, want) { + t.Errorf("wrong node. expected %x, bot %x", want, got) + } +} + func TestParseUUID(t *testing.T) { uuid, _ := ParseUUID("486f3a88-775b-11e3-ae07-d231feb1dc81") if uuid.Time() != time.Date(2014, 1, 7, 5, 19, 29, 222516000, time.UTC) { @@ -216,3 +240,46 @@ t.Fatalf("uuids not equal after marshalling: before=%s after=%s", u, u2) } } + +func TestMinTimeUUID(t *testing.T) { + aTime := time.Now() + minTimeUUID := MinTimeUUID(aTime) + + ts := aTime.Unix() + tsFromUUID := minTimeUUID.Time().Unix() + if ts != tsFromUUID { + t.Errorf("timestamps are not equal: expected %d, got %d", ts, tsFromUUID) + } + + clockFromUUID := minTimeUUID.Clock() + // clear two most significant bits, as they are used for IETF variant + if minClock&0x3FFF != clockFromUUID { + t.Errorf("clocks are not equal: expected %08b, got %08b", minClock&0x3FFF, clockFromUUID) + } + + nodeFromUUID := minTimeUUID.Node() + if !bytes.Equal(minNode, nodeFromUUID) { + t.Errorf("nodes are not equal: expected %08b, got %08b", minNode, nodeFromUUID) + } +} + +func TestMaxTimeUUID(t *testing.T) { + aTime := time.Now() + maxTimeUUID := MaxTimeUUID(aTime) + + ts := aTime.Unix() + tsFromUUID := maxTimeUUID.Time().Unix() + if ts != tsFromUUID { + t.Errorf("timestamps are not equal: expected %d, got %d", ts, tsFromUUID) + } + + clockFromUUID := maxTimeUUID.Clock() + if maxClock&0x3FFF != clockFromUUID { + t.Errorf("clocks are not equal: expected %08b, got %08b", maxClock&0x3FFF, clockFromUUID) + } + + nodeFromUUID := maxTimeUUID.Node() + if !bytes.Equal(maxNode, nodeFromUUID) { + t.Errorf("nodes are not equal: expected %08b, got %08b", maxNode, nodeFromUUID) + } +} diff -Nru golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/wiki_test.go golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/wiki_test.go --- golang-github-gocql-gocql-0.0~git20171009.0.2416cf3/wiki_test.go 2017-10-13 09:25:13.000000000 +0000 +++ golang-github-gocql-gocql-0.0~git20191102.0.9faa4c0/wiki_test.go 2019-11-02 13:15:23.000000000 +0000 @@ -1,4 +1,4 @@ -// +build all integration +// +build all cassandra package gocql
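
The sketch below is not part of the patch; it only illustrates, from application code, a few of the APIs this update introduces: the per-query Observer hook and ObservedQuery type, the Idempotent flag consulted by speculative execution, and the MinTimeUUID/MaxTimeUUID helpers for timeuuid range scans. The contact point, the "example" keyspace, and the events(bucket int, id timeuuid, val text) table are assumptions made for illustration only.

package main

import (
	"context"
	"log"
	"time"

	"github.com/gocql/gocql"
)

// latencyObserver is a minimal gocql.QueryObserver. ObserveQuery is invoked
// for every executed query, including each page fetched by an iterator.
type latencyObserver struct{}

func (latencyObserver) ObserveQuery(ctx context.Context, q gocql.ObservedQuery) {
	log.Printf("keyspace=%s stmt=%q took=%s err=%v",
		q.Keyspace, q.Statement, q.End.Sub(q.Start), q.Err)
}

func main() {
	cluster := gocql.NewCluster("127.0.0.1") // assumed contact point
	cluster.Keyspace = "example"             // assumed keyspace

	session, err := cluster.CreateSession()
	if err != nil {
		log.Fatal(err)
	}
	defer session.Close()

	// Marking a write as idempotent tells retry and speculative-execution
	// policies that it is safe to re-run on another host.
	if err := session.Query(
		`INSERT INTO events (bucket, id, val) VALUES (?, ?, ?)`,
		1, gocql.TimeUUID(), "hello",
	).Idempotent(true).Observer(latencyObserver{}).Exec(); err != nil {
		log.Fatal(err)
	}

	// MinTimeUUID and MaxTimeUUID produce non-unique boundary UUIDs that are
	// only meant for selecting a time range of a timeuuid column.
	since := time.Now().Add(-time.Hour)
	until := time.Now()
	iter := session.Query(
		`SELECT id FROM events WHERE bucket = ? AND id > ? AND id < ?`,
		1, gocql.MinTimeUUID(since), gocql.MaxTimeUUID(until),
	).Observer(latencyObserver{}).Iter()

	var id gocql.UUID
	for iter.Scan(&id) {
		log.Println("row:", id)
	}
	if err := iter.Close(); err != nil {
		log.Fatal(err)
	}
}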