diff --git a/README.md b/README.md index 4ae5adb..12ba58e 100644 --- a/README.md +++ b/README.md @@ -50,7 +50,7 @@ Just switch the import path from `cloud.google.com/go/datastore` to `github.com/ These features are unsupported just because we haven't found a use for the feature yet. PRs welcome: -* Embedded structs, nested slices, map types, some advanced query features (streaming aggregations, OR filters). +* Nested slices, map types, some advanced query features (streaming aggregations, OR filters). ## Testing diff --git a/pkg/datastore/batch_test.go b/pkg/datastore/batch_test.go new file mode 100644 index 0000000..9e50211 --- /dev/null +++ b/pkg/datastore/batch_test.go @@ -0,0 +1,94 @@ +package datastore_test + +import ( + "context" + "fmt" + "testing" + + "github.com/codeGROOVE-dev/ds9/pkg/datastore" +) + +func TestBatchOperations(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + type Item struct { + ID int + } + + // Number of items > 1000 to test batching limits + // Put limit: 500, Get limit: 1000 + const count = 1200 + keys := make([]*datastore.Key, count) + items := make([]Item, count) + + for i := range count { + keys[i] = datastore.NameKey("Item", fmt.Sprintf("item-%d", i), nil) + items[i] = Item{ID: i} + } + + // Test PutMulti + if _, err := client.PutMulti(ctx, keys, items); err != nil { + t.Fatalf("PutMulti failed: %v", err) + } + + // Test GetMulti + results := make([]Item, count) + if err := client.GetMulti(ctx, keys, &results); err != nil { + t.Fatalf("GetMulti failed: %v", err) + } + + for i := range count { + if results[i].ID != i { + t.Errorf("Item %d mismatch: got %d, want %d", i, results[i].ID, i) + } + } + + // Test DeleteMulti + if err := client.DeleteMulti(ctx, keys); err != nil { + t.Fatalf("DeleteMulti failed: %v", err) + } + + // Verify deletion + err := client.GetMulti(ctx, keys, &results) + // Should return MultiError with all ErrNoSuchEntity + if err == nil { + 
t.Fatal("Expected error after deletion, got nil") + } +} + +func TestAllocateIDsBatch(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Test AllocateIDs > 500 + const count = 600 + keys := make([]*datastore.Key, count) + for i := range count { + keys[i] = datastore.IncompleteKey("Item", nil) + } + + allocated, err := client.AllocateIDs(ctx, keys) + if err != nil { + t.Fatalf("AllocateIDs failed: %v", err) + } + + if len(allocated) != count { + t.Fatalf("Expected %d keys, got %d", count, len(allocated)) + } + + seen := make(map[int64]bool) + for i, k := range allocated { + if k.Incomplete() { + t.Errorf("Key %d is incomplete", i) + } + if seen[k.ID] { + t.Errorf("Duplicate ID %d at index %d", k.ID, i) + } + seen[k.ID] = true + } +} diff --git a/pkg/datastore/encode_coverage_test.go b/pkg/datastore/encode_coverage_test.go index 0292c1d..a897026 100644 --- a/pkg/datastore/encode_coverage_test.go +++ b/pkg/datastore/encode_coverage_test.go @@ -34,9 +34,9 @@ func TestEncodeValue_ReflectionSlices(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result, err := encodeValue(tt.value) + result, err := encodeAny(tt.value) if err != nil { - t.Errorf("encodeValue(%v) failed: %v", tt.value, err) + t.Errorf("encodeAny(%v) failed: %v", tt.value, err) } if result == nil { t.Error("Expected non-nil result") @@ -63,17 +63,14 @@ func TestEncodeValue_Errors(t *testing.T) { "channel type", make(chan int), }, - { - "struct type", - struct{ Name string }{Name: "test"}, - }, + // Note: struct types are now supported as nested entities } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - _, err := encodeValue(tt.value) + _, err := encodeAny(tt.value) if err == nil { - t.Errorf("encodeValue(%T) should have returned an error", tt.value) + t.Errorf("encodeAny(%T) should have returned an error", tt.value) } }) } @@ -86,7 +83,7 @@ func TestEncodeValue_TimeSlice(t *testing.T) { 
timeSlice := []time.Time{now, later} - result, err := encodeValue(timeSlice) + result, err := encodeAny(timeSlice) if err != nil { t.Fatalf("encodeValue failed for time slice: %v", err) } @@ -122,7 +119,7 @@ func TestEncodeValue_EmptySlices(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result, err := encodeValue(tt.value) + result, err := encodeAny(tt.value) if err != nil { t.Errorf("encodeValue failed: %v", err) } @@ -148,7 +145,7 @@ func TestEncodeValue_SingleElementSlices(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result, err := encodeValue(tt.value) + result, err := encodeAny(tt.value) if err != nil { t.Errorf("encodeValue failed: %v", err) } diff --git a/pkg/datastore/entity.go b/pkg/datastore/entity.go index edda13d..3537214 100644 --- a/pkg/datastore/entity.go +++ b/pkg/datastore/entity.go @@ -1,313 +1,10 @@ package datastore -import ( - "errors" - "fmt" - "reflect" - "strconv" - "strings" - "time" -) - -// encodeEntity converts a Go struct to a Datastore entity. 
-func encodeEntity(key *Key, src any) (map[string]any, error) { - v := reflect.ValueOf(src) - if v.Kind() == reflect.Ptr { - v = v.Elem() - } - - if v.Kind() != reflect.Struct { - return nil, errors.New("src must be a struct or pointer to struct") - } - - t := v.Type() - properties := make(map[string]any) - - for i := range v.NumField() { - field := t.Field(i) - value := v.Field(i) - - // Skip unexported fields - if !field.IsExported() { - continue - } - - // Get field name from datastore tag or use field name - name := field.Name - noIndex := false - - if tag := field.Tag.Get("datastore"); tag != "" { - parts := strings.Split(tag, ",") - if parts[0] != "" && parts[0] != "-" { - name = parts[0] - } - if len(parts) > 1 && parts[1] == "noindex" { - noIndex = true - } - if parts[0] == "-" { - continue - } - } - - prop, err := encodeValue(value.Interface()) - if err != nil { - return nil, fmt.Errorf("field %s: %w", field.Name, err) - } - - if noIndex { - if m, ok := prop.(map[string]any); ok { - m["excludeFromIndexes"] = true - } - } - - properties[name] = prop - } - - return map[string]any{ - "key": keyToJSON(key), - "properties": properties, - }, nil -} - -// encodeValue converts a Go value to a Datastore property value. 
-func encodeValue(v any) (any, error) { - if v == nil { - return map[string]any{"nullValue": nil}, nil - } - - switch val := v.(type) { - case string: - return map[string]any{"stringValue": val}, nil - case int: - return map[string]any{"integerValue": strconv.Itoa(val)}, nil - case int64: - return map[string]any{"integerValue": strconv.FormatInt(val, 10)}, nil - case int32: - return map[string]any{"integerValue": strconv.Itoa(int(val))}, nil - case bool: - return map[string]any{"booleanValue": val}, nil - case float64: - return map[string]any{"doubleValue": val}, nil - case time.Time: - return map[string]any{"timestampValue": val.Format(time.RFC3339Nano)}, nil - case []string: - values := make([]map[string]any, len(val)) - for i, s := range val { - values[i] = map[string]any{"stringValue": s} - } - return map[string]any{"arrayValue": map[string]any{"values": values}}, nil - case []int64: - values := make([]map[string]any, len(val)) - for i, n := range val { - values[i] = map[string]any{"integerValue": strconv.FormatInt(n, 10)} - } - return map[string]any{"arrayValue": map[string]any{"values": values}}, nil - case []int: - values := make([]map[string]any, len(val)) - for i, n := range val { - values[i] = map[string]any{"integerValue": strconv.Itoa(n)} - } - return map[string]any{"arrayValue": map[string]any{"values": values}}, nil - case []float64: - values := make([]map[string]any, len(val)) - for i, f := range val { - values[i] = map[string]any{"doubleValue": f} - } - return map[string]any{"arrayValue": map[string]any{"values": values}}, nil - case []bool: - values := make([]map[string]any, len(val)) - for i, b := range val { - values[i] = map[string]any{"booleanValue": b} - } - return map[string]any{"arrayValue": map[string]any{"values": values}}, nil - default: - // Try to handle slices/arrays via reflection - rv := reflect.ValueOf(v) - if rv.Kind() == reflect.Slice || rv.Kind() == reflect.Array { - length := rv.Len() - values := make([]map[string]any, length) - 
for i := range length { - elem := rv.Index(i).Interface() - encodedElem, err := encodeValue(elem) - if err != nil { - return nil, fmt.Errorf("failed to encode array element %d: %w", i, err) - } - // encodedElem is already a map[string]any with the type wrapper - m, ok := encodedElem.(map[string]any) - if !ok { - return nil, fmt.Errorf("unexpected encoded value type for element %d", i) - } - values[i] = m - } - return map[string]any{"arrayValue": map[string]any{"values": values}}, nil - } - return nil, fmt.Errorf("unsupported type: %T", v) - } -} - -// decodeEntity converts a Datastore entity to a Go struct. -func decodeEntity(entity map[string]any, dst any) error { - v := reflect.ValueOf(dst) - if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct { - return errors.New("dst must be a pointer to struct") - } - - v = v.Elem() - t := v.Type() - - properties, ok := entity["properties"].(map[string]any) - if !ok { - return errors.New("invalid entity format") - } - - for i := range v.NumField() { - field := t.Field(i) - value := v.Field(i) - - if !field.IsExported() { - continue - } - - // Get field name from datastore tag - name := field.Name - if tag := field.Tag.Get("datastore"); tag != "" { - parts := strings.Split(tag, ",") - if parts[0] != "" && parts[0] != "-" { - name = parts[0] - } - if parts[0] == "-" { - continue - } - } +import "errors" - prop, ok := properties[name] - if !ok { - continue // Field not in entity - } - - propMap, ok := prop.(map[string]any) - if !ok { - continue - } - - if err := decodeValue(propMap, value); err != nil { - return fmt.Errorf("field %s: %w", field.Name, err) - } - } - - return nil -} - -// decodeValue decodes a Datastore property value into a Go reflect.Value. 
-func decodeValue(prop map[string]any, dst reflect.Value) error { - // Handle each type - if val, ok := prop["stringValue"]; ok { - if dst.Kind() == reflect.String { - if s, ok := val.(string); ok { - dst.SetString(s) - return nil - } - } - } - - if val, ok := prop["integerValue"]; ok { - var intVal int64 - switch v := val.(type) { - case string: - if _, err := fmt.Sscanf(v, "%d", &intVal); err != nil { - return fmt.Errorf("invalid integer format: %w", err) - } - case float64: - intVal = int64(v) - } - - switch dst.Kind() { - case reflect.Int, reflect.Int64, reflect.Int32: - dst.SetInt(intVal) - return nil - default: - return fmt.Errorf("unsupported integer type: %v", dst.Kind()) - } - } - - if val, ok := prop["booleanValue"]; ok { - if dst.Kind() == reflect.Bool { - if b, ok := val.(bool); ok { - dst.SetBool(b) - return nil - } - } - } - - if val, ok := prop["doubleValue"]; ok { - if dst.Kind() == reflect.Float64 { - if f, ok := val.(float64); ok { - dst.SetFloat(f) - return nil - } - } - } - - if val, ok := prop["timestampValue"]; ok { - if dst.Type() == reflect.TypeOf(time.Time{}) { - if s, ok := val.(string); ok { - t, err := time.Parse(time.RFC3339Nano, s) - if err != nil { - return err - } - dst.Set(reflect.ValueOf(t)) - return nil - } - } - } - - if val, ok := prop["arrayValue"]; ok { - if dst.Kind() != reflect.Slice { - return fmt.Errorf("cannot decode array into non-slice type: %s", dst.Type()) - } - - arrayMap, ok := val.(map[string]any) - if !ok { - return errors.New("invalid arrayValue format") - } - - valuesAny, ok := arrayMap["values"] - if !ok { - // Empty array - dst.Set(reflect.MakeSlice(dst.Type(), 0, 0)) - return nil - } - - values, ok := valuesAny.([]any) - if !ok { - return errors.New("invalid arrayValue.values format") - } - - // Create slice with appropriate capacity - slice := reflect.MakeSlice(dst.Type(), len(values), len(values)) - - // Decode each element - for i, elemAny := range values { - elemMap, ok := elemAny.(map[string]any) - if 
!ok { - return fmt.Errorf("invalid array element %d format", i) - } - - elemValue := slice.Index(i) - if err := decodeValue(elemMap, elemValue); err != nil { - return fmt.Errorf("failed to decode array element %d: %w", i, err) - } - } - - dst.Set(slice) - return nil - } - - if _, ok := prop["nullValue"]; ok { - // Set to zero value - dst.Set(reflect.Zero(dst.Type())) - return nil - } - - return fmt.Errorf("unsupported property type for %s", dst.Type()) -} +// Entity encoding/decoding errors. +var ( + errNotStruct = errors.New("src must be a struct or pointer to struct") + errNotStructPtr = errors.New("dst must be a pointer to struct") + errInvalidEntity = errors.New("invalid entity format") +) diff --git a/pkg/datastore/entity_coverage_test.go b/pkg/datastore/entity_coverage_test.go index aeaac23..75b97d8 100644 --- a/pkg/datastore/entity_coverage_test.go +++ b/pkg/datastore/entity_coverage_test.go @@ -33,7 +33,7 @@ func TestEncodeValue_AllTypes(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - _, err := encodeValue(tt.value) + _, err := encodeAny(tt.value) if err != nil { t.Errorf("encodeValue failed: %v", err) } @@ -42,7 +42,7 @@ func TestEncodeValue_AllTypes(t *testing.T) { } func TestEncodeValue_UnsupportedType(t *testing.T) { - _, err := encodeValue(map[string]string{"key": "value"}) + _, err := encodeAny(map[string]string{"key": "value"}) if err == nil { t.Error("Expected error for unsupported type, got nil") } diff --git a/pkg/datastore/entity_decode.go b/pkg/datastore/entity_decode.go new file mode 100644 index 0000000..c2ec6ec --- /dev/null +++ b/pkg/datastore/entity_decode.go @@ -0,0 +1,394 @@ +package datastore + +import ( + "encoding/base64" + "errors" + "fmt" + "reflect" + "strconv" + "strings" + "time" +) + +// decodeEntity converts a Datastore entity to a Go struct. +// It also populates any field tagged with `datastore:"__key__"` with the entity's key. 
+func decodeEntity(entity map[string]any, dst any) error { + v := reflect.ValueOf(dst) + if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct { + return errNotStructPtr + } + + properties, ok := entity["properties"].(map[string]any) + if !ok { + return errInvalidEntity + } + + // Extract key if present + var key *Key + if keyData, ok := entity["key"]; ok { + var err error + key, err = keyFromJSON(keyData) + if err != nil { + // Non-fatal: continue without key + key = nil + } + } + + return decodeStruct(properties, v.Elem(), key, "") +} + +// decodeStruct decodes Datastore properties into a struct. +// key is the entity key (for __key__ field population). +// prefix is used for flattened fields (e.g., "Address."). +func decodeStruct(properties map[string]any, v reflect.Value, key *Key, prefix string) error { + t := v.Type() + + for i := range v.NumField() { + field := t.Field(i) + fieldVal := v.Field(i) + + if !field.IsExported() { + continue + } + + opts := parseDecodeTag(field) + if opts.skip { + continue + } + + // Handle __key__ field + if opts.name == "__key__" && key != nil { + if fieldVal.Type() == reflect.TypeOf((*Key)(nil)) { + fieldVal.Set(reflect.ValueOf(key)) + } + continue + } + + // Handle embedded (anonymous) structs + if field.Anonymous && fieldVal.Kind() == reflect.Struct { + if err := decodeStruct(properties, fieldVal, key, prefix); err != nil { + return fmt.Errorf("embedded %s: %w", field.Name, err) + } + continue + } + + propName := prefix + opts.name + + // Handle flatten for struct fields + if opts.flatten && isDecodableStruct(fieldVal) { + sv := fieldVal + if sv.Kind() == reflect.Ptr { + // Allocate if nil + if sv.IsNil() { + sv.Set(reflect.New(sv.Type().Elem())) + } + sv = sv.Elem() + } + if err := decodeStruct(properties, sv, key, propName+"."); err != nil { + return fmt.Errorf("field %s: %w", field.Name, err) + } + continue + } + + prop, ok := properties[propName] + if !ok { + continue + } + + propMap, ok := prop.(map[string]any) 
+ if !ok { + continue + } + + if err := decodeValue(propMap, fieldVal); err != nil { + return fmt.Errorf("field %s: %w", field.Name, err) + } + } + + return nil +} + +// decodeTagOptions holds parsed decode tag options. +type decodeTagOptions struct { + name string + flatten bool + skip bool +} + +// parseDecodeTag extracts field name and options from datastore tag for decoding. +func parseDecodeTag(field reflect.StructField) decodeTagOptions { + opts := decodeTagOptions{name: field.Name} + + tag := field.Tag.Get("datastore") + if tag == "" { + return opts + } + + parts := strings.Split(tag, ",") + if parts[0] == "-" { + opts.skip = true + return opts + } + if parts[0] != "" { + opts.name = parts[0] + } + + for _, opt := range parts[1:] { + if opt == "flatten" { + opts.flatten = true + } + } + + return opts +} + +// isDecodableStruct reports whether v is a struct or pointer to struct (excluding time.Time). +func isDecodableStruct(v reflect.Value) bool { + if v.Kind() == reflect.Struct { + return v.Type() != reflect.TypeOf(time.Time{}) + } + if v.Kind() == reflect.Ptr { + elem := v.Type().Elem() + return elem.Kind() == reflect.Struct && elem != reflect.TypeOf(time.Time{}) + } + return false +} + +// decodeValue decodes a Datastore property value into a Go reflect.Value. 
+func decodeValue(prop map[string]any, dst reflect.Value) error { + // Handle pointer destinations + if dst.Kind() == reflect.Ptr { + // Check for null + if _, ok := prop["nullValue"]; ok { + dst.Set(reflect.Zero(dst.Type())) + return nil + } + // Allocate if nil + if dst.IsNil() { + dst.Set(reflect.New(dst.Type().Elem())) + } + return decodeValue(prop, dst.Elem()) + } + + // Handle null for non-pointers + if _, ok := prop["nullValue"]; ok { + dst.Set(reflect.Zero(dst.Type())) + return nil + } + + // String + if val, ok := prop["stringValue"]; ok { + return decodeString(val, dst) + } + + // Integer + if val, ok := prop["integerValue"]; ok { + return decodeInteger(val, dst) + } + + // Boolean + if val, ok := prop["booleanValue"]; ok { + return decodeBool(val, dst) + } + + // Double/Float + if val, ok := prop["doubleValue"]; ok { + return decodeDouble(val, dst) + } + + // Timestamp + if val, ok := prop["timestampValue"]; ok { + return decodeTimestamp(val, dst) + } + + // Blob + if val, ok := prop["blobValue"]; ok { + return decodeBlob(val, dst) + } + + // Array + if val, ok := prop["arrayValue"]; ok { + return decodeArray(val, dst) + } + + // Entity (nested struct) + if val, ok := prop["entityValue"]; ok { + return decodeEntityValue(val, dst) + } + + // Key reference + if val, ok := prop["keyValue"]; ok { + return decodeKeyValue(val, dst) + } + + return fmt.Errorf("unsupported property type for %s", dst.Type()) +} + +func decodeString(val any, dst reflect.Value) error { + s, ok := val.(string) + if !ok { + return errors.New("invalid string value") + } + if dst.Kind() != reflect.String { + return fmt.Errorf("cannot decode string into %s", dst.Type()) + } + dst.SetString(s) + return nil +} + +func decodeInteger(val any, dst reflect.Value) error { + var intVal int64 + + switch v := val.(type) { + case string: + var err error + intVal, err = strconv.ParseInt(v, 10, 64) + if err != nil { + return fmt.Errorf("invalid integer: %w", err) + } + case float64: + intVal = 
int64(v) + default: + return fmt.Errorf("unexpected integer format: %T", val) + } + + switch dst.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + dst.SetInt(intVal) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if intVal < 0 { + return fmt.Errorf("cannot decode negative value %d into unsigned type", intVal) + } + dst.SetUint(uint64(intVal)) + default: + return fmt.Errorf("cannot decode integer into %s", dst.Type()) + } + return nil +} + +func decodeBool(val any, dst reflect.Value) error { + b, ok := val.(bool) + if !ok { + return errors.New("invalid boolean value") + } + if dst.Kind() != reflect.Bool { + return fmt.Errorf("cannot decode bool into %s", dst.Type()) + } + dst.SetBool(b) + return nil +} + +func decodeDouble(val any, dst reflect.Value) error { + f, ok := val.(float64) + if !ok { + return errors.New("invalid double value") + } + switch dst.Kind() { + case reflect.Float32, reflect.Float64: + dst.SetFloat(f) + default: + return fmt.Errorf("cannot decode double into %s", dst.Type()) + } + return nil +} + +func decodeTimestamp(val any, dst reflect.Value) error { + s, ok := val.(string) + if !ok { + return errors.New("invalid timestamp value") + } + if dst.Type() != reflect.TypeOf(time.Time{}) { + return fmt.Errorf("cannot decode timestamp into %s", dst.Type()) + } + t, err := time.Parse(time.RFC3339Nano, s) + if err != nil { + return fmt.Errorf("invalid timestamp format: %w", err) + } + dst.Set(reflect.ValueOf(t)) + return nil +} + +func decodeBlob(val any, dst reflect.Value) error { + s, ok := val.(string) + if !ok { + return errors.New("invalid blob value") + } + if dst.Type() != reflect.TypeOf([]byte(nil)) { + return fmt.Errorf("cannot decode blob into %s", dst.Type()) + } + data, err := base64.StdEncoding.DecodeString(s) + if err != nil { + return fmt.Errorf("invalid base64: %w", err) + } + dst.SetBytes(data) + return nil +} + +func decodeArray(val any, dst reflect.Value) 
error { + if dst.Kind() != reflect.Slice { + return fmt.Errorf("cannot decode array into %s", dst.Type()) + } + + arrayMap, ok := val.(map[string]any) + if !ok { + return errors.New("invalid arrayValue format") + } + + valuesAny, ok := arrayMap["values"] + if !ok { + dst.Set(reflect.MakeSlice(dst.Type(), 0, 0)) + return nil + } + + values, ok := valuesAny.([]any) + if !ok { + return errors.New("invalid arrayValue.values format") + } + + slice := reflect.MakeSlice(dst.Type(), len(values), len(values)) + + for i, elemAny := range values { + elemMap, ok := elemAny.(map[string]any) + if !ok { + return fmt.Errorf("invalid array element %d", i) + } + if err := decodeValue(elemMap, slice.Index(i)); err != nil { + return fmt.Errorf("element %d: %w", i, err) + } + } + + dst.Set(slice) + return nil +} + +func decodeEntityValue(val any, dst reflect.Value) error { + entityMap, ok := val.(map[string]any) + if !ok { + return errors.New("invalid entityValue format") + } + + properties, ok := entityMap["properties"].(map[string]any) + if !ok { + return errors.New("invalid entityValue.properties format") + } + + if dst.Kind() != reflect.Struct { + return fmt.Errorf("cannot decode entity into %s", dst.Type()) + } + + return decodeStruct(properties, dst, nil, "") +} + +func decodeKeyValue(val any, dst reflect.Value) error { + if dst.Type() != reflect.TypeOf((*Key)(nil)) { + return fmt.Errorf("cannot decode key into %s", dst.Type()) + } + + key, err := keyFromJSON(val) + if err != nil { + return fmt.Errorf("invalid key: %w", err) + } + + dst.Set(reflect.ValueOf(key)) + return nil +} diff --git a/pkg/datastore/entity_encode.go b/pkg/datastore/entity_encode.go new file mode 100644 index 0000000..344c9d6 --- /dev/null +++ b/pkg/datastore/entity_encode.go @@ -0,0 +1,294 @@ +package datastore + +import ( + "encoding/base64" + "fmt" + "reflect" + "strconv" + "strings" + "time" +) + +// tagOptions holds parsed struct field tag options. 
+type tagOptions struct { + name string + noIndex bool + omitempty bool + flatten bool + skip bool +} + +// encodeEntity converts a Go struct to a Datastore entity. +func encodeEntity(key *Key, src any) (map[string]any, error) { + v := reflect.ValueOf(src) + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + + if v.Kind() != reflect.Struct { + return nil, errNotStruct + } + + properties, err := encodeStruct(v, "") + if err != nil { + return nil, err + } + + return map[string]any{ + "key": keyToJSON(key), + "properties": properties, + }, nil +} + +// encodeStruct encodes a struct value to Datastore properties. +// prefix is used for flattened nested structs (e.g., "Address."). +func encodeStruct(v reflect.Value, prefix string) (map[string]any, error) { + t := v.Type() + properties := make(map[string]any) + + for i := range v.NumField() { + field := t.Field(i) + fieldVal := v.Field(i) + + if !field.IsExported() { + continue + } + + opts := parseTag(field) + if opts.skip { + continue + } + + // Handle embedded (anonymous) structs + if field.Anonymous && fieldVal.Kind() == reflect.Struct { + embedded, err := encodeStruct(fieldVal, prefix) + if err != nil { + return nil, fmt.Errorf("embedded %s: %w", field.Name, err) + } + for k, v := range embedded { + properties[k] = v + } + continue + } + + // Check omitempty before encoding + if opts.omitempty && isEmpty(fieldVal) { + continue + } + + propName := prefix + opts.name + + // Handle flatten for struct fields + if opts.flatten && isStructOrStructPtr(fieldVal) { + sv := fieldVal + if sv.Kind() == reflect.Ptr { + if sv.IsNil() { + continue // Skip nil pointers when flattening + } + sv = sv.Elem() + } + flattened, err := encodeStruct(sv, propName+".") + if err != nil { + return nil, fmt.Errorf("field %s: %w", field.Name, err) + } + for k, v := range flattened { + properties[k] = v + } + continue + } + + prop, err := encodeValue(fieldVal) + if err != nil { + return nil, fmt.Errorf("field %s: %w", field.Name, err) + } + + if 
opts.noIndex { + if m, ok := prop.(map[string]any); ok { + m["excludeFromIndexes"] = true + } + } + + properties[propName] = prop + } + + return properties, nil +} + +// parseTag extracts field name and options from datastore tag. +func parseTag(field reflect.StructField) tagOptions { + opts := tagOptions{name: field.Name} + + tag := field.Tag.Get("datastore") + if tag == "" { + return opts + } + + parts := strings.Split(tag, ",") + if parts[0] == "-" { + opts.skip = true + return opts + } + if parts[0] != "" { + opts.name = parts[0] + } + + for _, opt := range parts[1:] { + switch opt { + case "noindex": + opts.noIndex = true + case "omitempty": + opts.omitempty = true + case "flatten": + opts.flatten = true + default: + // Ignore unknown options + } + } + + return opts +} + +// isEmpty reports whether v is the zero value for its type. +func isEmpty(v reflect.Value) bool { + switch v.Kind() { + case reflect.Ptr, reflect.Interface: + return v.IsNil() + case reflect.Slice, reflect.Map: + return v.IsNil() || v.Len() == 0 + case reflect.String: + return v.Len() == 0 + case reflect.Bool: + return !v.Bool() + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v.Int() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return v.Uint() == 0 + case reflect.Float32, reflect.Float64: + return v.Float() == 0 + case reflect.Struct: + // Special case for time.Time + if t, ok := v.Interface().(time.Time); ok { + return t.IsZero() + } + return false + default: + return false + } +} + +// isStructOrStructPtr reports whether v is a struct or pointer to struct. +func isStructOrStructPtr(v reflect.Value) bool { + if v.Kind() == reflect.Struct { + return true + } + if v.Kind() == reflect.Ptr && v.Type().Elem().Kind() == reflect.Struct { + return true + } + return false +} + +// encodeAny converts any Go value to a Datastore property value. 
+func encodeAny(v any) (any, error) { + if v == nil { + return map[string]any{"nullValue": nil}, nil + } + return encodeValue(reflect.ValueOf(v)) +} + +// encodeValue converts a Go reflect.Value to a Datastore property value. +func encodeValue(v reflect.Value) (any, error) { + // Handle pointers - dereference or return null + if v.Kind() == reflect.Ptr { + if v.IsNil() { + return map[string]any{"nullValue": nil}, nil + } + v = v.Elem() + } + + // Handle interface{} - get underlying value + if v.Kind() == reflect.Interface { + if v.IsNil() { + return map[string]any{"nullValue": nil}, nil + } + v = v.Elem() + } + + // Check for specific types first (before kind switch) + switch val := v.Interface().(type) { + case time.Time: + return map[string]any{"timestampValue": val.Format(time.RFC3339Nano)}, nil + case *Key: + if val == nil { + return map[string]any{"nullValue": nil}, nil + } + return map[string]any{"keyValue": keyToJSON(val)}, nil + } + + // Handle by kind + switch v.Kind() { + case reflect.String: + return map[string]any{"stringValue": v.String()}, nil + + case reflect.Bool: + return map[string]any{"booleanValue": v.Bool()}, nil + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return map[string]any{"integerValue": strconv.FormatInt(v.Int(), 10)}, nil + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: + return map[string]any{"integerValue": strconv.FormatUint(v.Uint(), 10)}, nil + + case reflect.Float32, reflect.Float64: + return map[string]any{"doubleValue": v.Float()}, nil + + case reflect.Slice, reflect.Array: + return encodeSlice(v) + + case reflect.Struct: + return encodeNestedStruct(v) + + default: + return nil, fmt.Errorf("unsupported type: %s", v.Type()) + } +} + +// encodeSlice encodes a slice or array to a Datastore array value. 
+func encodeSlice(v reflect.Value) (any, error) { + // Special case: []byte becomes blobValue (only for slices, not arrays) + if v.Kind() == reflect.Slice && v.Type().Elem().Kind() == reflect.Uint8 { + data := v.Bytes() + return map[string]any{"blobValue": base64.StdEncoding.EncodeToString(data)}, nil + } + + length := v.Len() + values := make([]map[string]any, length) + + for i := range length { + elem := v.Index(i) + encoded, err := encodeValue(elem) + if err != nil { + return nil, fmt.Errorf("element %d: %w", i, err) + } + m, ok := encoded.(map[string]any) + if !ok { + return nil, fmt.Errorf("unexpected encoded type for element %d", i) + } + values[i] = m + } + + return map[string]any{"arrayValue": map[string]any{"values": values}}, nil +} + +// encodeNestedStruct encodes a nested struct as an entity value. +func encodeNestedStruct(v reflect.Value) (any, error) { + properties, err := encodeStruct(v, "") + if err != nil { + return nil, err + } + + return map[string]any{ + "entityValue": map[string]any{ + "properties": properties, + }, + }, nil +} diff --git a/pkg/datastore/flush_pattern_test.go b/pkg/datastore/flush_pattern_test.go new file mode 100644 index 0000000..3837177 --- /dev/null +++ b/pkg/datastore/flush_pattern_test.go @@ -0,0 +1,72 @@ +package datastore_test + +import ( + "context" + "testing" + + "github.com/codeGROOVE-dev/ds9/pkg/datastore" +) + +// TestFlushPattern verifies the pattern used in sfcache's Flush method. +// It simulates GetAll with a KeysOnly query and a dummy destination. +func TestFlushPattern(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + kind := "CacheEntry" + + // 1. 
Put some entries + keys := []*datastore.Key{ + datastore.NameKey(kind, "key1", nil), + datastore.NameKey(kind, "key2", nil), + } + + type Entry struct { + Value string + } + + entries := []Entry{ + {Value: "val1"}, + {Value: "val2"}, + } + + if _, err := client.PutMulti(ctx, keys, entries); err != nil { + t.Fatalf("PutMulti failed: %v", err) + } + + // 2. Simulate Flush: Query KeysOnly + query := datastore.NewQuery(kind).KeysOnly() + var dst []Entry // Dummy destination + + // 3. GetAll with dst + gotKeys, err := client.GetAll(ctx, query, &dst) + if err != nil { + t.Fatalf("GetAll failed with KeysOnly: %v", err) + } + + // 4. Verify keys returned + if len(gotKeys) != 2 { + t.Errorf("Expected 2 keys, got %d", len(gotKeys)) + } + + // 5. Verify dst was NOT populated (optional, but expected behavior) + if len(dst) != 0 { + t.Errorf("Expected dst to remain empty, got %d items", len(dst)) + } + + // 6. DeleteMulti using returned keys + if err := client.DeleteMulti(ctx, gotKeys); err != nil { + t.Fatalf("DeleteMulti failed: %v", err) + } + + // 7. 
Verify deletion + qCheck := datastore.NewQuery(kind).KeysOnly() + remaining, err := client.GetAll(ctx, qCheck, nil) + if err != nil { + t.Fatalf("Verification query failed: %v", err) + } + if len(remaining) != 0 { + t.Errorf("Expected 0 remaining items, got %d", len(remaining)) + } +} diff --git a/pkg/datastore/iterator.go b/pkg/datastore/iterator.go index a79424f..2d8c55a 100644 --- a/pkg/datastore/iterator.go +++ b/pkg/datastore/iterator.go @@ -63,9 +63,12 @@ func (it *Iterator) Next(dst any) (*Key, error) { it.cursor = result.cursor } - // Decode entity into dst - if err := decodeEntity(result.entity, dst); err != nil { - return nil, err + // For KeysOnly queries, skip entity decoding - just return the key + // The Datastore API returns entities without properties for keys-only queries + if !it.query.keysOnly { + if err := decodeEntity(result.entity, dst); err != nil { + return nil, err + } } return result.key, nil diff --git a/pkg/datastore/key.go b/pkg/datastore/key.go index 844aa35..2936d49 100644 --- a/pkg/datastore/key.go +++ b/pkg/datastore/key.go @@ -11,42 +11,55 @@ import ( // Key represents a Datastore key. type Key struct { - Parent *Key // Parent key for hierarchical keys - Kind string - Name string // For string keys - ID int64 // For numeric keys + Namespace string + Parent *Key // Parent key for hierarchical keys + Kind string + Name string // For string keys + ID int64 // For numeric keys } // NameKey creates a new key with a string name. // The parent parameter can be nil for top-level keys. // This matches the API of cloud.google.com/go/datastore. func NameKey(kind, name string, parent *Key) *Key { - return &Key{ + k := &Key{ Kind: kind, Name: name, Parent: parent, } + if parent != nil { + k.Namespace = parent.Namespace + } + return k } // IDKey creates a new key with a numeric ID. // The parent parameter can be nil for top-level keys. // This matches the API of cloud.google.com/go/datastore. 
func IDKey(kind string, id int64, parent *Key) *Key { - return &Key{ + k := &Key{ Kind: kind, ID: id, Parent: parent, } + if parent != nil { + k.Namespace = parent.Namespace + } + return k } // IncompleteKey creates a new incomplete key. // The key will be completed (assigned an ID) when the entity is saved. // API compatible with cloud.google.com/go/datastore. func IncompleteKey(kind string, parent *Key) *Key { - return &Key{ + k := &Key{ Kind: kind, Parent: parent, } + if parent != nil { + k.Namespace = parent.Namespace + } + return k } // Incomplete returns true if the key does not have an ID or Name. @@ -64,7 +77,7 @@ func (k *Key) Equal(other *Key) bool { if k == nil || other == nil { return false } - if k.Kind != other.Kind || k.Name != other.Name || k.ID != other.ID { + if k.Namespace != other.Namespace || k.Kind != other.Kind || k.Name != other.Name || k.ID != other.ID { return false } // Recursively check parent keys @@ -93,7 +106,11 @@ func (k *Key) String() string { parts = append([]string{part}, parts...) } - return "/" + strings.Join(parts, "/") + keyStr := "/" + strings.Join(parts, "/") + if k.Namespace != "" { + keyStr = fmt.Sprintf("[%s]%s", k.Namespace, keyStr) + } + return keyStr } // Encode returns an opaque representation of the key. @@ -172,9 +189,18 @@ func keyToJSON(key *Key) map[string]any { path = append(path, elem) } - return map[string]any{ + m := map[string]any{ "path": path, } + + // Add partitionId if namespace is present + if key.Namespace != "" { + m["partitionId"] = map[string]any{ + "namespaceId": key.Namespace, + } + } + + return m } // keyFromJSON converts a JSON key representation to a Key. 
@@ -189,6 +215,13 @@ func keyFromJSON(keyData any) (*Key, error) { return nil, errors.New("invalid key path") } + var namespace string + if pid, ok := keyMap["partitionId"].(map[string]any); ok { + if ns, ok := pid["namespaceId"].(string); ok { + namespace = ns + } + } + // Build key hierarchy from path elements var key *Key for _, elem := range path { @@ -198,7 +231,8 @@ func keyFromJSON(keyData any) (*Key, error) { } newKey := &Key{ - Parent: key, + Parent: key, + Namespace: namespace, } if kind, ok := elemMap["kind"].(string); ok { diff --git a/pkg/datastore/multierror_test.go b/pkg/datastore/multierror_test.go index d9dacdf..d501ab6 100644 --- a/pkg/datastore/multierror_test.go +++ b/pkg/datastore/multierror_test.go @@ -138,16 +138,16 @@ func TestMultiErrorGetMulti_NilKeys(t *testing.T) { t.Fatalf("Expected MultiError, got %T", err) } - if multiErr[0] != nil { - t.Errorf("Expected no error for key[0], got: %v", multiErr[0]) + if !errors.Is(multiErr[0], datastore.ErrNoSuchEntity) { + t.Errorf("Expected ErrNoSuchEntity for key[0], got: %v", multiErr[0]) } if multiErr[1] == nil { t.Error("Expected error for nil key at index 1") } else if !errors.Is(multiErr[1], datastore.ErrInvalidKey) { t.Errorf("Expected ErrInvalidKey for nil key, got: %v", multiErr[1]) } - if multiErr[2] != nil { - t.Errorf("Expected no error for key[2], got: %v", multiErr[2]) + if !errors.Is(multiErr[2], datastore.ErrNoSuchEntity) { + t.Errorf("Expected ErrNoSuchEntity for key[2], got: %v", multiErr[2]) } } diff --git a/pkg/datastore/namespace_test.go b/pkg/datastore/namespace_test.go new file mode 100644 index 0000000..f1602d3 --- /dev/null +++ b/pkg/datastore/namespace_test.go @@ -0,0 +1,119 @@ +package datastore_test + +import ( + "context" + "testing" + + "github.com/codeGROOVE-dev/ds9/pkg/datastore" +) + +func TestNamespaceIsolation(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + type Data struct { + Value string + } + 
+ // Create keys with different namespaces + keyDefault := datastore.NameKey("Data", "item1", nil) + + keyNS1 := datastore.NameKey("Data", "item1", nil) + keyNS1.Namespace = "ns1" + + keyNS2 := datastore.NameKey("Data", "item1", nil) + keyNS2.Namespace = "ns2" + + // Verify keys are different + if keyDefault.String() == keyNS1.String() { + t.Fatal("Keys should be different due to namespace") + } + + // Put entities + if _, err := client.Put(ctx, keyDefault, &Data{Value: "default"}); err != nil { + t.Fatal(err) + } + if _, err := client.Put(ctx, keyNS1, &Data{Value: "namespace1"}); err != nil { + t.Fatal(err) + } + if _, err := client.Put(ctx, keyNS2, &Data{Value: "namespace2"}); err != nil { + t.Fatal(err) + } + + // Get entities and verify isolation + var dest Data + + // Check Default + if err := client.Get(ctx, keyDefault, &dest); err != nil { + t.Errorf("Get default failed: %v", err) + } + if dest.Value != "default" { + t.Errorf("Expected 'default', got '%s'", dest.Value) + } + + // Check NS1 + if err := client.Get(ctx, keyNS1, &dest); err != nil { + t.Errorf("Get NS1 failed: %v", err) + } + if dest.Value != "namespace1" { + t.Errorf("Expected 'namespace1', got '%s'", dest.Value) + } + + // Check NS2 + if err := client.Get(ctx, keyNS2, &dest); err != nil { + t.Errorf("Get NS2 failed: %v", err) + } + if dest.Value != "namespace2" { + t.Errorf("Expected 'namespace2', got '%s'", dest.Value) + } +} + +func TestNamespaceQuery(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + type Item struct { + N int + } + + // Create items in ns1 + for i := range 5 { + k := datastore.IncompleteKey("Item", nil) + k.Namespace = "ns1" + if _, err := client.Put(ctx, k, &Item{N: i}); err != nil { + t.Fatal(err) + } + } + + // Create items in default namespace + for i := range 3 { + k := datastore.IncompleteKey("Item", nil) + if _, err := client.Put(ctx, k, &Item{N: i}); err != nil { + t.Fatal(err) + } + } + + // Query 
ns1 + q1 := datastore.NewQuery("Item").Namespace("ns1") + var items1 []Item + if _, err := client.GetAll(ctx, q1, &items1); err != nil { + t.Fatalf("Query ns1 failed: %v", err) + } + if len(items1) != 5 { + t.Errorf("Expected 5 items in ns1, got %d", len(items1)) + } + + // Query default + qDef := datastore.NewQuery("Item") + var itemsDef []Item + if _, err := client.GetAll(ctx, qDef, &itemsDef); err != nil { + t.Fatalf("Query default failed: %v", err) + } + if len(itemsDef) != 3 { + t.Errorf("Expected 3 items in default, got %d", len(itemsDef)) + } +} diff --git a/pkg/datastore/operations.go b/pkg/datastore/operations.go index 726bf3b..414d187 100644 --- a/pkg/datastore/operations.go +++ b/pkg/datastore/operations.go @@ -11,6 +11,12 @@ import ( "github.com/codeGROOVE-dev/ds9/auth" ) +const ( + maxLookupBatch = 1000 + maxMutationBatch = 500 + maxAllocationBatch = 500 +) + // Get retrieves an entity by key and stores it in dst. // dst must be a pointer to a struct. // Returns ErrNoSuchEntity if the key is not found. 
@@ -186,11 +192,10 @@ func (c *Client) GetMulti(ctx context.Context, keys []*Key, dst any) error { c.logger.WarnContext(ctx, "GetMulti called with nil key", "index", i) multiErr[i] = fmt.Errorf("%w: key at index %d cannot be nil", ErrInvalidKey, i) hasErr = true + } else { + multiErr[i] = ErrNoSuchEntity // Default to not found } } - if hasErr { - return multiErr - } // Decode into slice dstValue := reflect.ValueOf(dst) @@ -199,6 +204,7 @@ func (c *Client) GetMulti(ctx context.Context, keys []*Key, dst any) error { } sliceType := dstValue.Elem().Type() + resultSlice := reflect.MakeSlice(sliceType, len(keys), len(keys)) token, err := auth.AccessToken(ctx) if err != nil { @@ -206,10 +212,79 @@ func (c *Client) GetMulti(ctx context.Context, keys []*Key, dst any) error { return fmt.Errorf("failed to get access token: %w", err) } - // Build keys array - jsonKeys := make([]map[string]any, len(keys)) - for i, key := range keys { - jsonKeys[i] = keyToJSON(key) + // Process in batches + for i := 0; i < len(keys); i += maxLookupBatch { + end := i + maxLookupBatch + if end > len(keys) { + end = len(keys) + } + + batchKeys := keys[i:end] + batchIndices := make([]int, len(batchKeys)) + for k := range batchKeys { + batchIndices[k] = i + k + } + + // Skip batch if all keys are nil + allNil := true + for _, k := range batchKeys { + if k != nil { + allNil = false + break + } + } + if allNil { + continue + } + + if err := c.getMultiBatch(ctx, batchKeys, batchIndices, i, token, resultSlice, multiErr); err != nil { + // Batch failure handled inside getMultiBatch by updating multiErr + // We just check if we need to set hasErr + hasErr = true + } + } + + // Check if any errors occurred (including NoSuchEntity) + for _, e := range multiErr { + if e != nil { + hasErr = true + break + } + } + + // Set the result slice + dstValue.Elem().Set(resultSlice) + + if hasErr { + return multiErr + } + + c.logger.DebugContext(ctx, "entities retrieved successfully", "count", len(keys)) + return nil 
+} + +// getMultiBatch processes a single batch of keys for GetMulti. +func (c *Client) getMultiBatch( + ctx context.Context, + batchKeys []*Key, + batchIndices []int, + batchOffset int, + token string, + resultSlice reflect.Value, + multiErr MultiError, +) error { + // Build keys array for this batch + jsonKeys := make([]map[string]any, 0, len(batchKeys)) + keyMap := make(map[string][]int) // Map key string to original index + + for k, key := range batchKeys { + if key == nil { + continue + } + jsonKeys = append(jsonKeys, keyToJSON(key)) + keyStr := key.String() + idx := batchOffset + k + keyMap[keyStr] = append(keyMap[keyStr], idx) } reqBody := map[string]any{ @@ -225,11 +300,17 @@ func (c *Client) GetMulti(ctx context.Context, keys []*Key, dst any) error { return fmt.Errorf("failed to marshal request: %w", err) } - // URL-encode project ID to prevent injection attacks reqURL := fmt.Sprintf("%s/projects/%s:lookup", c.baseURL, neturl.PathEscape(c.projectID)) body, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID) if err != nil { - c.logger.ErrorContext(ctx, "lookup request failed", "error", err) + c.logger.ErrorContext(ctx, "lookup request failed for batch", "batch_start", batchOffset, "error", err) + // Mark all keys in this batch as failed + for _, idx := range batchIndices { + // Don't overwrite existing errors (like nil key) + if errors.Is(multiErr[idx], ErrNoSuchEntity) { + multiErr[idx] = err + } + } return err } @@ -244,22 +325,13 @@ func (c *Client) GetMulti(ctx context.Context, keys []*Key, dst any) error { if err := json.Unmarshal(body, &result); err != nil { c.logger.ErrorContext(ctx, "failed to parse response", "error", err) - return fmt.Errorf("failed to parse response: %w", err) - } - - // Create a map of keys to their indices - keyMap := make(map[string][]int) - for i, key := range keys { - keyStr := key.String() - keyMap[keyStr] = append(keyMap[keyStr], i) - } - - // Create slice to hold results, sized to match 
keys - resultSlice := reflect.MakeSlice(sliceType, len(keys), len(keys)) - - // Mark all as not found initially - for i := range keys { - multiErr[i] = ErrNoSuchEntity + // Mark batch as failed + for _, idx := range batchIndices { + if errors.Is(multiErr[idx], ErrNoSuchEntity) { + multiErr[idx] = fmt.Errorf("failed to parse response: %w", err) + } + } + return err } // Process found entities @@ -267,7 +339,7 @@ func (c *Client) GetMulti(ctx context.Context, keys []*Key, dst any) error { key, err := keyFromJSON(found.Entity["key"]) if err != nil { c.logger.ErrorContext(ctx, "failed to parse key from response", "error", err) - return err + continue } keyStr := key.String() @@ -281,44 +353,13 @@ func (c *Client) GetMulti(ctx context.Context, keys []*Key, dst any) error { if err := decodeEntity(found.Entity, elem.Addr().Interface()); err != nil { c.logger.ErrorContext(ctx, "failed to decode entity", "index", index, "error", err) multiErr[index] = err - hasErr = true } else { multiErr[index] = nil // Success } } } + // Missing entities remain as ErrNoSuchEntity - // Process missing entities - they already have ErrNoSuchEntity set - for _, missing := range result.Missing { - key, err := keyFromJSON(missing.Entity["key"]) - if err != nil { - continue - } - keyStr := key.String() - if indices, ok := keyMap[keyStr]; ok { - for range indices { - hasErr = true - // multiErr[index] already set to ErrNoSuchEntity above - } - } - } - - // Check if any entities are still marked as missing - for i, e := range multiErr { - if errors.Is(e, ErrNoSuchEntity) { - hasErr = true - c.logger.DebugContext(ctx, "entity not found", "index", i, "key", keys[i].String()) - } - } - - // Set the result slice - dstValue.Elem().Set(resultSlice) - - if hasErr { - return multiErr - } - - c.logger.DebugContext(ctx, "entities retrieved successfully", "count", len(keys)) return nil } @@ -347,62 +388,83 @@ func (c *Client) PutMulti(ctx context.Context, keys []*Key, src any) ([]*Key, er return nil, 
fmt.Errorf("keys and src length mismatch: %d != %d", len(keys), v.Len()) } - // Validate keys and encode entities upfront multiErr := make(MultiError, len(keys)) hasErr := false - mutations := make([]map[string]any, len(keys)) - for i, key := range keys { - if key == nil { - c.logger.WarnContext(ctx, "PutMulti called with nil key", "index", i) - multiErr[i] = fmt.Errorf("%w: key at index %d cannot be nil", ErrInvalidKey, i) - hasErr = true - continue - } + token, err := auth.AccessToken(ctx) + if err != nil { + c.logger.ErrorContext(ctx, "failed to get access token", "error", err) + return nil, fmt.Errorf("failed to get access token: %w", err) + } - entity, err := encodeEntity(key, v.Index(i).Interface()) - if err != nil { - c.logger.ErrorContext(ctx, "failed to encode entity", "error", err, "index", i) - multiErr[i] = err - hasErr = true - continue + // Process in batches + for i := 0; i < len(keys); i += maxMutationBatch { + end := i + maxMutationBatch + if end > len(keys) { + end = len(keys) } - mutations[i] = map[string]any{ - "upsert": entity, + batchLen := end - i + mutations := make([]map[string]any, 0, batchLen) + batchIndices := make([]int, 0, batchLen) + + // Prepare batch mutations + for k := range batchLen { + idx := i + k + key := keys[idx] + + if key == nil { + c.logger.WarnContext(ctx, "PutMulti called with nil key", "index", idx) + multiErr[idx] = fmt.Errorf("%w: key at index %d cannot be nil", ErrInvalidKey, idx) + hasErr = true + continue + } + + entity, err := encodeEntity(key, v.Index(idx).Interface()) + if err != nil { + c.logger.ErrorContext(ctx, "failed to encode entity", "error", err, "index", idx) + multiErr[idx] = err + hasErr = true + continue + } + + mutations = append(mutations, map[string]any{ + "upsert": entity, + }) + batchIndices = append(batchIndices, idx) } - } - // If encoding failed, return MultiError - if hasErr { - return nil, multiErr - } + if len(mutations) == 0 { + continue + } - token, err := auth.AccessToken(ctx) - if err 
!= nil { - c.logger.ErrorContext(ctx, "failed to get access token", "error", err) - return nil, fmt.Errorf("failed to get access token: %w", err) - } + reqBody := map[string]any{ + "mode": "NON_TRANSACTIONAL", + "mutations": mutations, + } + if c.databaseID != "" { + reqBody["databaseId"] = c.databaseID + } - reqBody := map[string]any{ - "mode": "NON_TRANSACTIONAL", - "mutations": mutations, - } - if c.databaseID != "" { - reqBody["databaseId"] = c.databaseID - } + jsonData, err := json.Marshal(reqBody) + if err != nil { + c.logger.ErrorContext(ctx, "failed to marshal request", "error", err) + return nil, fmt.Errorf("failed to marshal request: %w", err) + } - jsonData, err := json.Marshal(reqBody) - if err != nil { - c.logger.ErrorContext(ctx, "failed to marshal request", "error", err) - return nil, fmt.Errorf("failed to marshal request: %w", err) + reqURL := fmt.Sprintf("%s/projects/%s:commit", c.baseURL, neturl.PathEscape(c.projectID)) + if _, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID); err != nil { + c.logger.ErrorContext(ctx, "commit request failed", "error", err) + // Mark valid keys in this batch as failed + for _, idx := range batchIndices { + multiErr[idx] = err + hasErr = true + } + } } - // URL-encode project ID to prevent injection attacks - reqURL := fmt.Sprintf("%s/projects/%s:commit", c.baseURL, neturl.PathEscape(c.projectID)) - if _, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID); err != nil { - c.logger.ErrorContext(ctx, "commit request failed", "error", err) - return nil, err + if hasErr { + return keys, multiErr } c.logger.DebugContext(ctx, "entities stored successfully", "count", len(keys)) @@ -420,28 +482,8 @@ func (c *Client) DeleteMulti(ctx context.Context, keys []*Key) error { c.logger.DebugContext(ctx, "deleting multiple entities", "count", len(keys)) - // Validate keys upfront multiErr := make(MultiError, len(keys)) hasErr := false - mutations := 
make([]map[string]any, len(keys)) - - for i, key := range keys { - if key == nil { - c.logger.WarnContext(ctx, "DeleteMulti called with nil key", "index", i) - multiErr[i] = fmt.Errorf("%w: key at index %d cannot be nil", ErrInvalidKey, i) - hasErr = true - continue - } - - mutations[i] = map[string]any{ - "delete": keyToJSON(key), - } - } - - // If validation failed, return MultiError - if hasErr { - return multiErr - } token, err := auth.AccessToken(ctx) if err != nil { @@ -449,25 +491,65 @@ func (c *Client) DeleteMulti(ctx context.Context, keys []*Key) error { return fmt.Errorf("failed to get access token: %w", err) } - reqBody := map[string]any{ - "mode": "NON_TRANSACTIONAL", - "mutations": mutations, - } - if c.databaseID != "" { - reqBody["databaseId"] = c.databaseID - } + // Process in batches + for i := 0; i < len(keys); i += maxMutationBatch { + end := i + maxMutationBatch + if end > len(keys) { + end = len(keys) + } - jsonData, err := json.Marshal(reqBody) - if err != nil { - c.logger.ErrorContext(ctx, "failed to marshal request", "error", err) - return fmt.Errorf("failed to marshal request: %w", err) + batchLen := end - i + mutations := make([]map[string]any, 0, batchLen) + batchIndices := make([]int, 0, batchLen) + + for k := range batchLen { + idx := i + k + key := keys[idx] + + if key == nil { + c.logger.WarnContext(ctx, "DeleteMulti called with nil key", "index", idx) + multiErr[idx] = fmt.Errorf("%w: key at index %d cannot be nil", ErrInvalidKey, idx) + hasErr = true + continue + } + + mutations = append(mutations, map[string]any{ + "delete": keyToJSON(key), + }) + batchIndices = append(batchIndices, idx) + } + + if len(mutations) == 0 { + continue + } + + reqBody := map[string]any{ + "mode": "NON_TRANSACTIONAL", + "mutations": mutations, + } + if c.databaseID != "" { + reqBody["databaseId"] = c.databaseID + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + c.logger.ErrorContext(ctx, "failed to marshal request", "error", err) + return 
fmt.Errorf("failed to marshal request: %w", err) + } + + reqURL := fmt.Sprintf("%s/projects/%s:commit", c.baseURL, neturl.PathEscape(c.projectID)) + if _, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID); err != nil { + c.logger.ErrorContext(ctx, "delete request failed", "error", err) + // Mark valid keys in this batch as failed + for _, idx := range batchIndices { + multiErr[idx] = err + hasErr = true + } + } } - // URL-encode project ID to prevent injection attacks - reqURL := fmt.Sprintf("%s/projects/%s:commit", c.baseURL, neturl.PathEscape(c.projectID)) - if _, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID); err != nil { - c.logger.ErrorContext(ctx, "delete request failed", "error", err) - return err + if hasErr { + return multiErr } c.logger.DebugContext(ctx, "entities deleted successfully", "count", len(keys)) @@ -536,50 +618,57 @@ func (c *Client) AllocateIDs(ctx context.Context, keys []*Key) ([]*Key, error) { return nil, fmt.Errorf("failed to get access token: %w", err) } - // Build request with incomplete keys - reqKeys := make([]map[string]any, len(incompleteKeys)) - for i, key := range incompleteKeys { - reqKeys[i] = keyToJSON(key) - } + // Process in batches + allocatedKeys := make([]*Key, len(incompleteKeys)) - reqBody := map[string]any{ - "keys": reqKeys, - } - if c.databaseID != "" { - reqBody["databaseId"] = c.databaseID - } + for i := 0; i < len(incompleteKeys); i += maxAllocationBatch { + end := i + maxAllocationBatch + if end > len(incompleteKeys) { + end = len(incompleteKeys) + } - jsonData, err := json.Marshal(reqBody) - if err != nil { - c.logger.ErrorContext(ctx, "failed to marshal request", "error", err) - return nil, fmt.Errorf("failed to marshal request: %w", err) - } + batchKeys := incompleteKeys[i:end] + reqKeys := make([]map[string]any, len(batchKeys)) + for k, key := range batchKeys { + reqKeys[k] = keyToJSON(key) + } - // URL-encode project ID to prevent injection 
attacks - reqURL := fmt.Sprintf("%s/projects/%s:allocateIds", c.baseURL, neturl.PathEscape(c.projectID)) - body, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID) - if err != nil { - c.logger.ErrorContext(ctx, "allocateIds request failed", "error", err) - return nil, err - } + reqBody := map[string]any{ + "keys": reqKeys, + } + if c.databaseID != "" { + reqBody["databaseId"] = c.databaseID + } - var resp struct { - Keys []map[string]any `json:"keys"` - } - if err := json.Unmarshal(body, &resp); err != nil { - c.logger.ErrorContext(ctx, "failed to parse response", "error", err) - return nil, fmt.Errorf("failed to parse allocateIds response: %w", err) - } + jsonData, err := json.Marshal(reqBody) + if err != nil { + c.logger.ErrorContext(ctx, "failed to marshal request", "error", err) + return nil, fmt.Errorf("failed to marshal request: %w", err) + } - // Parse allocated keys - allocatedKeys := make([]*Key, len(resp.Keys)) - for i, keyData := range resp.Keys { - key, err := keyFromJSON(keyData) + reqURL := fmt.Sprintf("%s/projects/%s:allocateIds", c.baseURL, neturl.PathEscape(c.projectID)) + body, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID) if err != nil { - c.logger.ErrorContext(ctx, "failed to parse allocated key", "index", i, "error", err) - return nil, fmt.Errorf("failed to parse allocated key at index %d: %w", i, err) + c.logger.ErrorContext(ctx, "allocateIds request failed", "error", err) + return nil, err + } + + var resp struct { + Keys []map[string]any `json:"keys"` + } + if err := json.Unmarshal(body, &resp); err != nil { + c.logger.ErrorContext(ctx, "failed to parse response", "error", err) + return nil, fmt.Errorf("failed to parse allocateIds response: %w", err) + } + + for k, keyData := range resp.Keys { + key, err := keyFromJSON(keyData) + if err != nil { + c.logger.ErrorContext(ctx, "failed to parse allocated key", "index", i+k, "error", err) + return nil, fmt.Errorf("failed to 
parse allocated key at index %d: %w", i+k, err) + } + allocatedKeys[i+k] = key } - allocatedKeys[i] = key } // Create result slice with allocated keys in correct positions diff --git a/pkg/datastore/operations_misc_test.go b/pkg/datastore/operations_misc_test.go index 9509fac..a9f5c5d 100644 --- a/pkg/datastore/operations_misc_test.go +++ b/pkg/datastore/operations_misc_test.go @@ -1377,11 +1377,12 @@ func TestArraySliceSupport(t *testing.T) { t.Fatalf("Get failed: %v", err) } - if result.Strings == nil || len(result.Strings) != 0 { - t.Errorf("Expected empty string slice, got %v", result.Strings) + // With omitempty tag, empty slices are not stored, so we get nil back + if result.Strings != nil { + t.Errorf("Expected nil string slice (omitempty), got %v", result.Strings) } - if result.Ints == nil || len(result.Ints) != 0 { - t.Errorf("Expected empty int slice, got %v", result.Ints) + if result.Ints != nil { + t.Errorf("Expected nil int slice (omitempty), got %v", result.Ints) } }) diff --git a/pkg/datastore/query.go b/pkg/datastore/query.go index 6a0ef2a..bbd3004 100644 --- a/pkg/datastore/query.go +++ b/pkg/datastore/query.go @@ -194,18 +194,11 @@ func buildQueryMap(query *Query) map[string]any { "kind": []map[string]any{{"name": query.kind}}, } - // Add namespace via partition ID if specified - if query.namespace != "" { - queryMap["partitionId"] = map[string]any{ - "namespaceId": query.namespace, - } - } - // Add filters if len(query.filters) > 0 { var compositeFilters []map[string]any for _, f := range query.filters { - encodedVal, err := encodeValue(f.value) + encodedVal, err := encodeAny(f.value) if err != nil { // Skip invalid filters continue @@ -341,6 +334,9 @@ func (c *Client) AllKeys(ctx context.Context, q *Query) ([]*Key, error) { if c.databaseID != "" { reqBody["databaseId"] = c.databaseID } + if q.namespace != "" { + reqBody["partitionId"] = map[string]any{"namespaceId": q.namespace} + } jsonData, err := json.Marshal(reqBody) if err != nil { @@ 
-403,6 +399,9 @@ func (c *Client) GetAll(ctx context.Context, query *Query, dst any) ([]*Key, err if c.databaseID != "" { reqBody["databaseId"] = c.databaseID } + if query.namespace != "" { + reqBody["partitionId"] = map[string]any{"namespaceId": query.namespace} + } jsonData, err := json.Marshal(reqBody) if err != nil { @@ -431,8 +430,9 @@ func (c *Client) GetAll(ctx context.Context, query *Query, dst any) ([]*Key, err return nil, fmt.Errorf("failed to parse response: %w", err) } - // For KeysOnly queries, dst can be nil - just return keys - if query.keysOnly && dst == nil { + // For KeysOnly queries, skip entity decoding - just return keys + // The Datastore API returns entities without properties for keys-only queries + if query.keysOnly { keys := make([]*Key, 0, len(result.Batch.EntityResults)) for _, er := range result.Batch.EntityResults { key, err := keyFromJSON(er.Entity["key"]) @@ -513,6 +513,9 @@ func (c *Client) Count(ctx context.Context, q *Query) (int, error) { if c.databaseID != "" { reqBody["databaseId"] = c.databaseID } + if q.namespace != "" { + reqBody["partitionId"] = map[string]any{"namespaceId": q.namespace} + } jsonData, err := json.Marshal(reqBody) if err != nil { diff --git a/pkg/mock/mock.go b/pkg/mock/mock.go index 0ecb7b4..52879b0 100644 --- a/pkg/mock/mock.go +++ b/pkg/mock/mock.go @@ -200,32 +200,11 @@ func (s *Store) handleLookup(w http.ResponseWriter, r *http.Request) { defer s.mu.RUnlock() for _, keyData := range req.Keys { - path, ok := keyData["path"].([]any) - if !ok { - continue - } - if len(path) == 0 { - continue - } - pathElem, ok := path[0].(map[string]any) - if !ok { - continue - } - kind, ok := pathElem["kind"].(string) + keyStr, ok := s.extractKeyString(keyData) if !ok { continue } - // Handle both name and ID keys - var keyStr string - if name, ok := pathElem["name"].(string); ok { - keyStr = kind + "/" + name - } else if id, ok := pathElem["id"].(string); ok { - keyStr = kind + "/" + id - } else { - continue - } - if 
entity, exists := s.entities[keyStr]; exists { found = append(found, map[string]any{ "entity": entity, @@ -517,12 +496,20 @@ func (s *Store) resolveKey(keyData map[string]any) (keyStr string, updatedKey ma return "", nil, false } + // Extract namespace + namespace := "" + if pid, ok := keyData["partitionId"].(map[string]any); ok { + if ns, ok := pid["namespaceId"].(string); ok { + namespace = ns + } + } + // Handle both name and ID keys if name, ok := pathElem["name"].(string); ok { - return kind + "/" + name, keyData, true + return namespace + "!" + kind + "/" + name, keyData, true } if id, ok := pathElem["id"].(string); ok { - return kind + "/" + id, keyData, true + return namespace + "!" + kind + "/" + id, keyData, true } // Incomplete key - allocate an ID @@ -530,7 +517,7 @@ func (s *Store) resolveKey(keyData map[string]any) (keyStr string, updatedKey ma allocatedID := strconv.FormatInt(s.nextID, 10) pathElem["id"] = allocatedID - return kind + "/" + allocatedID, keyData, true + return namespace + "!" + kind + "/" + allocatedID, keyData, true } // extractKeyString extracts the key string from key data. @@ -548,12 +535,20 @@ func (*Store) extractKeyString(keyData map[string]any) (string, bool) { return "", false } + // Extract namespace + namespace := "" + if pid, ok := keyData["partitionId"].(map[string]any); ok { + if ns, ok := pid["namespaceId"].(string); ok { + namespace = ns + } + } + // Handle both name and ID keys if name, ok := pathElem["name"].(string); ok { - return kind + "/" + name, true + return namespace + "!" + kind + "/" + name, true } if id, ok := pathElem["id"].(string); ok { - return kind + "/" + id, true + return namespace + "!" + kind + "/" + id, true } return "", false } @@ -566,11 +561,30 @@ type queryResult struct { entity map[string]any } +// isKeysOnlyQuery checks if the query has a projection for only __key__. 
+func isKeysOnlyQuery(query map[string]any) bool { + projection, ok := query["projection"].([]any) + if !ok || len(projection) != 1 { + return false + } + proj, ok := projection[0].(map[string]any) + if !ok { + return false + } + prop, ok := proj["property"].(map[string]any) + if !ok { + return false + } + name, ok := prop["name"].(string) + return ok && name == "__key__" +} + // handleRunQuery handles query requests. func (s *Store) handleRunQuery(w http.ResponseWriter, r *http.Request) { var req struct { - Query map[string]any `json:"query"` - DatabaseID string `json:"databaseId"` + Query map[string]any `json:"query"` + PartitionID map[string]any `json:"partitionId"` + DatabaseID string `json:"databaseId"` } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { @@ -604,6 +618,14 @@ func (s *Store) handleRunQuery(w http.ResponseWriter, r *http.Request) { return } + // Filter by namespace + namespace := "" + if req.PartitionID != nil { + if ns, ok := req.PartitionID["namespaceId"].(string); ok { + namespace = ns + } + } + var limit int if l, ok := query["limit"].(float64); ok { limit = int(l) @@ -644,7 +666,15 @@ func (s *Store) handleRunQuery(w http.ResponseWriter, r *http.Request) { continue } - if entityKind == kind { + // Check kind and namespace + entityNamespace := "" + if pid, ok := keyData["partitionId"].(map[string]any); ok { + if ns, ok := pid["namespaceId"].(string); ok { + entityNamespace = ns + } + } + + if entityKind == kind && entityNamespace == namespace { // Apply filters if present if filterMap, hasFilter := query["filter"].(map[string]any); hasFilter { if !matchesFilter(entity, filterMap) { @@ -679,12 +709,24 @@ func (s *Store) handleRunQuery(w http.ResponseWriter, r *http.Request) { matches = matches[:limit] } + // Check if this is a keys-only query (projection contains only __key__) + keysOnly := isKeysOnlyQuery(query) + // Build results results := make([]any, 0, len(matches)) for _, m := range matches { - results = append(results, 
map[string]any{ - "entity": m.entity, - }) + if keysOnly { + // For keys-only queries, return entity with only the key (no properties) + results = append(results, map[string]any{ + "entity": map[string]any{ + "key": m.entity["key"], + }, + }) + } else { + results = append(results, map[string]any{ + "entity": m.entity, + }) + } } // Generate cursor for pagination @@ -993,6 +1035,29 @@ func matchesFilter(entity map[string]any, filterMap map[string]any) bool { } filterValue := propFilter["value"] + // Handle HAS_ANCESTOR: match entities whose key path starts with the + // ancestor key's path. + // + // The Datastore REST API encodes key operands as {"keyValue": {...}}. + // The previous fallback re-asserted filterValue with the same type + // assertion, so the wrapper was never unwrapped and HAS_ANCESTOR + // filters could never match; unwrap the keyValue field explicitly. + if operator == "HAS_ANCESTOR" { + ancestorKeyData, ok := filterValue.(map[string]any) + if !ok { + return false + } + if ak, ok := ancestorKeyData["keyValue"].(map[string]any); ok { + ancestorKeyData = ak + } + // Check if entity key has prefix of ancestor key path + entityKeyData, ok := entity["key"].(map[string]any) + if !ok { + return false + } + return isAncestor(ancestorKeyData, entityKeyData) + } + // Get entity properties properties, ok := entity["properties"].(map[string]any) if !ok { @@ -1102,6 +1167,42 @@ func matchesFilter(entity map[string]any, filterMap map[string]any) bool { return true // No filter or unrecognized filter, allow all } +// isAncestor checks if ancestorKey is a prefix of entityKey. 
+func isAncestor(ancestorKey, entityKey map[string]any) bool { + ancPath, ok1 := ancestorKey["path"].([]any) + entPath, ok2 := entityKey["path"].([]any) + + if !ok1 || !ok2 { + return false + } + + if len(ancPath) > len(entPath) { + return false + } + + // Check equality of path elements + for i := range ancPath { + ap, ok1 := ancPath[i].(map[string]any) + ep, ok2 := entPath[i].(map[string]any) + + if !ok1 || !ok2 { + return false + } + + if ap["kind"] != ep["kind"] { + return false + } + if ap["name"] != ep["name"] { + return false + } + if ap["id"] != ep["id"] { + return false + } + } + + return true +} + // handleRunAggregationQuery handles :runAggregationQuery requests. func (s *Store) handleRunAggregationQuery(w http.ResponseWriter, r *http.Request) { var req struct { //nolint:govet // Local anonymous struct for JSON unmarshaling