diff --git a/.asf.yaml b/.asf.yaml index 2820e1da..47dcfce3 100644 --- a/.asf.yaml +++ b/.asf.yaml @@ -29,11 +29,13 @@ github: rebase: false squash: true features: + discussions: true issues: true protected_branches: main: required_linear_history: true notifications: + discussions: user@arrow.apache.org commits: commits@arrow.apache.org issues_status: issues@arrow.apache.org issues_comment: github@arrow.apache.org diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 4636eb32..ef940374 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -40,11 +40,11 @@ jobs: with: fetch-depth: 0 - name: Setup Python - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 with: python-version: 3.12 - name: Setup Go - uses: actions/setup-go@f111f3307d8850f501ac008e886eec1fd1932a34 # v5.3.0 + uses: actions/setup-go@0aaccfd150d50ccaeb58ebd88d36e91967a5f35b # v5.4.0 with: go-version: '1.23' cache: true diff --git a/.github/workflows/rc.yml b/.github/workflows/rc.yml index a94fb27f..0d81cb0a 100644 --- a/.github/workflows/rc.yml +++ b/.github/workflows/rc.yml @@ -68,7 +68,7 @@ jobs: - name: Audit run: | dev/release/run_rat.sh "${TAR_GZ}" - - uses: actions/upload-artifact@4cec3d8aa04e39d1a68397de0c4cd6fb9dce8ec1 # v4.6.1 + - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 with: name: archive path: | @@ -89,7 +89,7 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: submodules: recursive - - uses: actions/download-artifact@cc203385981b70ca67e1cc392babf9cc229d5806 # v4.1.9 + - uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4.3.0 with: name: archive - name: Verify @@ -119,7 +119,7 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: submodules: recursive - - uses: actions/download-artifact@cc203385981b70ca67e1cc392babf9cc229d5806 # v4.1.9 + - uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4.3.0 with: name: archive - name: Upload diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0cdd4ed1..1ff24c36 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -63,7 +63,7 @@ jobs: with: submodules: recursive - name: Login to GitHub Container registry - uses: docker/login-action@v3 + uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 # v3.4.0 with: registry: ghcr.io username: ${{ github.actor }} @@ -176,7 +176,7 @@ jobs: with: submodules: recursive - name: Setup Go - uses: actions/setup-go@f111f3307d8850f501ac008e886eec1fd1932a34 # v5.3.0 + uses: actions/setup-go@0aaccfd150d50ccaeb58ebd88d36e91967a5f35b # v5.4.0 with: go-version: ${{ matrix.go }} cache: true @@ -205,7 +205,7 @@ jobs: with: submodules: recursive - name: Setup Go - uses: actions/setup-go@f111f3307d8850f501ac008e886eec1fd1932a34 # v5.3.0 + uses: actions/setup-go@0aaccfd150d50ccaeb58ebd88d36e91967a5f35b # v5.4.0 with: go-version: ${{ matrix.go }} cache: true @@ -214,7 +214,7 @@ jobs: run: brew install apache-arrow - name: Setup PKG_CONFIG_PATH run: | - echo "PKG_CONFIG_PATH=$(brew --prefix openssl@3)/lib/pkgconfig:$PKG_CONFIG_PATH" >> $GITHUB_ENV + echo "PKG_CONFIG_PATH=$(brew --prefix openssl@3)/lib/pkgconfig:$(brew --prefix)/lib/pkgconfig:$PKG_CONFIG_PATH" >> $GITHUB_ENV - name: Build run: | ci/scripts/build.sh $(pwd) @@ -237,7 +237,7 @@ jobs: with: submodules: recursive - name: Setup Go - uses: 
actions/setup-go@f111f3307d8850f501ac008e886eec1fd1932a34 # v5.3.0 + uses: actions/setup-go@0aaccfd150d50ccaeb58ebd88d36e91967a5f35b # v5.4.0 with: go-version: ${{ matrix.go }} cache: true @@ -291,7 +291,7 @@ jobs: echo "CGO_LDFLAGS=-g -O2 -L$(cygpath --windows ${MINGW_PREFIX}/lib) -L$(cygpath --windows ${MINGW_PREFIX}/bin)" >> $GITHUB_ENV echo "MINGW_PREFIX=$(cygpath --windows ${MINGW_PREFIX})" >> $GITHUB_ENV - name: Setup Go - uses: actions/setup-go@f111f3307d8850f501ac008e886eec1fd1932a34 # v5.3.0 + uses: actions/setup-go@0aaccfd150d50ccaeb58ebd88d36e91967a5f35b # v5.4.0 with: go-version: "${{ env.GO_VERSION }}" cache: true @@ -315,7 +315,7 @@ jobs: run: | (. .env && echo "GO_VERSION=${GO}") >> $GITHUB_ENV - name: Setup Go - uses: actions/setup-go@f111f3307d8850f501ac008e886eec1fd1932a34 # v5.3.0 + uses: actions/setup-go@0aaccfd150d50ccaeb58ebd88d36e91967a5f35b # v5.4.0 with: go-version: "${{ env.GO_VERSION }}" cache: true @@ -384,7 +384,7 @@ jobs: key: integration-conda-${{ hashFiles('cpp/**') }} restore-keys: conda- - name: Setup Python - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 with: python-version: 3.12 - name: Setup Archery diff --git a/arrow/Makefile b/arrow/Makefile index 9c4a2326..c7e32709 100644 --- a/arrow/Makefile +++ b/arrow/Makefile @@ -30,7 +30,7 @@ assembly: @$(MAKE) -C math assembly generate: bin/tmpl - bin/tmpl -i -data=numeric.tmpldata type_traits_numeric.gen.go.tmpl type_traits_numeric.gen_test.go.tmpl array/numeric.gen.go.tmpl array/numericbuilder.gen_test.go.tmpl array/numericbuilder.gen.go.tmpl array/bufferbuilder_numeric.gen.go.tmpl + bin/tmpl -i -data=numeric.tmpldata type_traits_numeric.gen.go.tmpl type_traits_numeric.gen_test.go.tmpl array/numericbuilder.gen_test.go.tmpl array/numericbuilder.gen.go.tmpl array/bufferbuilder_numeric.gen.go.tmpl bin/tmpl -i -data=datatype_numeric.gen.go.tmpldata datatype_numeric.gen.go.tmpl @$(MAKE) -C math generate diff --git a/arrow/array.go b/arrow/array.go index df186f2d..d42ca6d0 100644 --- a/arrow/array.go +++ b/arrow/array.go @@ -127,3 +127,15 @@ type Array interface { // When the reference count goes to zero, the memory is freed. Release() } + +// ValueType is a generic constraint for valid Arrow primitive types +type ValueType interface { + bool | FixedWidthType | string | []byte +} + +// TypedArray is an interface representing an Array of a particular type +// allowing for easy propagation of generics +type TypedArray[T ValueType] interface { + Array + Value(int) T +} diff --git a/arrow/array/array.go b/arrow/array/array.go index 6e281a43..947b44f2 100644 --- a/arrow/array/array.go +++ b/arrow/array/array.go @@ -35,7 +35,7 @@ const ( ) type array struct { - refCount int64 + refCount atomic.Int64 data *Data nullBitmapBytes []byte } @@ -43,16 +43,16 @@ type array struct { // Retain increases the reference count by 1. // Retain may be called simultaneously from multiple goroutines. func (a *array) Retain() { - atomic.AddInt64(&a.refCount, 1) + a.refCount.Add(1) } // Release decreases the reference count by 1. // Release may be called simultaneously from multiple goroutines. // When the reference count goes to zero, the memory is freed. 
func (a *array) Release() { - debug.Assert(atomic.LoadInt64(&a.refCount) > 0, "too many releases") + debug.Assert(a.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&a.refCount, -1) == 0 { + if a.refCount.Add(-1) == 0 { a.data.Release() a.data, a.nullBitmapBytes = nil, nil } @@ -109,9 +109,7 @@ func (a *array) Offset() int { type arrayConstructorFn func(arrow.ArrayData) arrow.Array -var ( - makeArrayFn [64]arrayConstructorFn -) +var makeArrayFn [64]arrayConstructorFn func invalidDataType(data arrow.ArrayData) arrow.Array { panic("invalid data type: " + data.DataType().ID().String()) diff --git a/arrow/array/binary.go b/arrow/array/binary.go index 1af7631b..5fef60ec 100644 --- a/arrow/array/binary.go +++ b/arrow/array/binary.go @@ -45,7 +45,7 @@ type Binary struct { // NewBinaryData constructs a new Binary array from data. func NewBinaryData(data arrow.ArrayData) *Binary { a := &Binary{} - a.refCount = 1 + a.refCount.Add(1) a.setData(data.(*Data)) return a } @@ -189,7 +189,7 @@ type LargeBinary struct { func NewLargeBinaryData(data arrow.ArrayData) *LargeBinary { a := &LargeBinary{} - a.refCount = 1 + a.refCount.Add(1) a.setData(data.(*Data)) return a } @@ -208,6 +208,7 @@ func (a *LargeBinary) ValueStr(i int) string { } return base64.StdEncoding.EncodeToString(a.Value(i)) } + func (a *LargeBinary) ValueString(i int) string { b := a.Value(i) return *(*string)(unsafe.Pointer(&b)) @@ -333,7 +334,7 @@ type BinaryView struct { func NewBinaryViewData(data arrow.ArrayData) *BinaryView { a := &BinaryView{} - a.refCount = 1 + a.refCount.Add(1) a.setData(data.(*Data)) return a } @@ -450,4 +451,8 @@ var ( _ BinaryLike = (*Binary)(nil) _ BinaryLike = (*LargeBinary)(nil) + + _ arrow.TypedArray[[]byte] = (*Binary)(nil) + _ arrow.TypedArray[[]byte] = (*LargeBinary)(nil) + _ arrow.TypedArray[[]byte] = (*BinaryView)(nil) ) diff --git a/arrow/array/binarybuilder.go b/arrow/array/binarybuilder.go index 794ac688..8b162c77 100644 --- a/arrow/array/binarybuilder.go +++ b/arrow/array/binarybuilder.go @@ -22,7 +22,6 @@ import ( "fmt" "math" "reflect" - "sync/atomic" "unsafe" "github.com/apache/arrow-go/v18/arrow" @@ -72,8 +71,8 @@ func NewBinaryBuilder(mem memory.Allocator, dtype arrow.BinaryDataType) *BinaryB offsetByteWidth = arrow.Int64SizeBytes } - b := &BinaryBuilder{ - builder: builder{refCount: 1, mem: mem}, + bb := &BinaryBuilder{ + builder: builder{mem: mem}, dtype: dtype, offsets: offsets, values: newByteBufferBuilder(mem), @@ -82,7 +81,8 @@ func NewBinaryBuilder(mem memory.Allocator, dtype arrow.BinaryDataType) *BinaryB offsetByteWidth: offsetByteWidth, getOffsetVal: getOffsetVal, } - return b + bb.builder.refCount.Add(1) + return bb } func (b *BinaryBuilder) Type() arrow.DataType { return b.dtype } @@ -91,9 +91,9 @@ func (b *BinaryBuilder) Type() arrow.DataType { return b.dtype } // When the reference count goes to zero, the memory is freed. // Release may be called simultaneously from multiple goroutines. 
func (b *BinaryBuilder) Release() { - debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases") + debug.Assert(b.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&b.refCount, -1) == 0 { + if b.refCount.Add(-1) == 0 { if b.nullBitmap != nil { b.nullBitmap.Release() b.nullBitmap = nil @@ -387,18 +387,19 @@ type BinaryViewBuilder struct { } func NewBinaryViewBuilder(mem memory.Allocator) *BinaryViewBuilder { - return &BinaryViewBuilder{ + bvb := &BinaryViewBuilder{ dtype: arrow.BinaryTypes.BinaryView, builder: builder{ - refCount: 1, - mem: mem, + mem: mem, }, blockBuilder: multiBufferBuilder{ - refCount: 1, blockSize: dfltBlockSize, mem: mem, }, } + bvb.builder.refCount.Add(1) + bvb.blockBuilder.refCount.Add(1) + return bvb } func (b *BinaryViewBuilder) SetBlockSize(sz uint) { @@ -408,9 +409,9 @@ func (b *BinaryViewBuilder) SetBlockSize(sz uint) { func (b *BinaryViewBuilder) Type() arrow.DataType { return b.dtype } func (b *BinaryViewBuilder) Release() { - debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases") + debug.Assert(b.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&b.refCount, -1) != 0 { + if b.refCount.Add(-1) != 0 { return } @@ -673,7 +674,8 @@ func (b *BinaryViewBuilder) newData() (data *Data) { dataBuffers := b.blockBuilder.Finish() data = NewData(b.dtype, b.length, append([]*memory.Buffer{ - b.nullBitmap, b.data}, dataBuffers...), nil, b.nulls, 0) + b.nullBitmap, b.data, + }, dataBuffers...), nil, b.nulls, 0) b.reset() if b.data != nil { diff --git a/arrow/array/boolean.go b/arrow/array/boolean.go index fb2dba73..1b28a9f4 100644 --- a/arrow/array/boolean.go +++ b/arrow/array/boolean.go @@ -44,7 +44,7 @@ func NewBoolean(length int, data *memory.Buffer, nullBitmap *memory.Buffer, null func NewBooleanData(data arrow.ArrayData) *Boolean { a := &Boolean{} - a.refCount = 1 + a.refCount.Add(1) a.setData(data.(*Data)) return a } @@ -122,5 +122,6 @@ func arrayEqualBoolean(left, right *Boolean) bool { } var ( - _ arrow.Array = (*Boolean)(nil) + _ arrow.Array = (*Boolean)(nil) + _ arrow.TypedArray[bool] = (*Boolean)(nil) ) diff --git a/arrow/array/booleanbuilder.go b/arrow/array/booleanbuilder.go index 951fe3a9..a277ffd2 100644 --- a/arrow/array/booleanbuilder.go +++ b/arrow/array/booleanbuilder.go @@ -21,7 +21,6 @@ import ( "fmt" "reflect" "strconv" - "sync/atomic" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/bitutil" @@ -38,7 +37,9 @@ type BooleanBuilder struct { } func NewBooleanBuilder(mem memory.Allocator) *BooleanBuilder { - return &BooleanBuilder{builder: builder{refCount: 1, mem: mem}} + bb := &BooleanBuilder{builder: builder{mem: mem}} + bb.builder.refCount.Add(1) + return bb } func (b *BooleanBuilder) Type() arrow.DataType { return arrow.FixedWidthTypes.Boolean } @@ -47,9 +48,9 @@ func (b *BooleanBuilder) Type() arrow.DataType { return arrow.FixedWidthTypes.Bo // When the reference count goes to zero, the memory is freed. // Release may be called simultaneously from multiple goroutines. 
func (b *BooleanBuilder) Release() { - debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases") + debug.Assert(b.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&b.refCount, -1) == 0 { + if b.refCount.Add(-1) == 0 { if b.nullBitmap != nil { b.nullBitmap.Release() b.nullBitmap = nil @@ -258,6 +259,4 @@ func (b *BooleanBuilder) Value(i int) bool { return bitutil.BitIsSet(b.rawData, i) } -var ( - _ Builder = (*BooleanBuilder)(nil) -) +var _ Builder = (*BooleanBuilder)(nil) diff --git a/arrow/array/bufferbuilder.go b/arrow/array/bufferbuilder.go index 085d43ef..bc784d6a 100644 --- a/arrow/array/bufferbuilder.go +++ b/arrow/array/bufferbuilder.go @@ -43,7 +43,7 @@ type bufBuilder interface { // A bufferBuilder provides common functionality for populating memory with a sequence of type-specific values. // Specialized implementations provide type-safe APIs for appending and accessing the memory. type bufferBuilder struct { - refCount int64 + refCount atomic.Int64 mem memory.Allocator buffer *memory.Buffer length int @@ -55,16 +55,16 @@ type bufferBuilder struct { // Retain increases the reference count by 1. // Retain may be called simultaneously from multiple goroutines. func (b *bufferBuilder) Retain() { - atomic.AddInt64(&b.refCount, 1) + b.refCount.Add(1) } // Release decreases the reference count by 1. // When the reference count goes to zero, the memory is freed. // Release may be called simultaneously from multiple goroutines. func (b *bufferBuilder) Release() { - debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases") + debug.Assert(b.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&b.refCount, -1) == 0 { + if b.refCount.Add(-1) == 0 { if b.buffer != nil { b.buffer.Release() b.buffer, b.bytes = nil, nil @@ -155,7 +155,7 @@ func (b *bufferBuilder) unsafeAppend(data []byte) { } type multiBufferBuilder struct { - refCount int64 + refCount atomic.Int64 blockSize int mem memory.Allocator @@ -166,16 +166,16 @@ type multiBufferBuilder struct { // Retain increases the reference count by 1. // Retain may be called simultaneously from multiple goroutines. func (b *multiBufferBuilder) Retain() { - atomic.AddInt64(&b.refCount, 1) + b.refCount.Add(1) } // Release decreases the reference count by 1. // When the reference count goes to zero, the memory is freed. // Release may be called simultaneously from multiple goroutines. 
func (b *multiBufferBuilder) Release() { - debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases") + debug.Assert(b.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&b.refCount, -1) == 0 { + if b.refCount.Add(-1) == 0 { b.Reset() } } diff --git a/arrow/array/bufferbuilder_byte.go b/arrow/array/bufferbuilder_byte.go index 78bb938e..61431b71 100644 --- a/arrow/array/bufferbuilder_byte.go +++ b/arrow/array/bufferbuilder_byte.go @@ -23,7 +23,9 @@ type byteBufferBuilder struct { } func newByteBufferBuilder(mem memory.Allocator) *byteBufferBuilder { - return &byteBufferBuilder{bufferBuilder: bufferBuilder{refCount: 1, mem: mem}} + bbb := &byteBufferBuilder{bufferBuilder: bufferBuilder{mem: mem}} + bbb.bufferBuilder.refCount.Add(1) + return bbb } func (b *byteBufferBuilder) Values() []byte { return b.Bytes() } diff --git a/arrow/array/bufferbuilder_numeric.gen.go b/arrow/array/bufferbuilder_numeric.gen.go index 3812c5e7..e887fbf1 100644 --- a/arrow/array/bufferbuilder_numeric.gen.go +++ b/arrow/array/bufferbuilder_numeric.gen.go @@ -29,7 +29,9 @@ type int64BufferBuilder struct { } func newInt64BufferBuilder(mem memory.Allocator) *int64BufferBuilder { - return &int64BufferBuilder{bufferBuilder: bufferBuilder{refCount: 1, mem: mem}} + b := &int64BufferBuilder{bufferBuilder: bufferBuilder{mem: mem}} + b.refCount.Add(1) + return b } // AppendValues appends the contents of v to the buffer, growing the buffer as needed. @@ -62,7 +64,9 @@ type int32BufferBuilder struct { } func newInt32BufferBuilder(mem memory.Allocator) *int32BufferBuilder { - return &int32BufferBuilder{bufferBuilder: bufferBuilder{refCount: 1, mem: mem}} + b := &int32BufferBuilder{bufferBuilder: bufferBuilder{mem: mem}} + b.refCount.Add(1) + return b } // AppendValues appends the contents of v to the buffer, growing the buffer as needed. @@ -95,7 +99,9 @@ type int8BufferBuilder struct { } func newInt8BufferBuilder(mem memory.Allocator) *int8BufferBuilder { - return &int8BufferBuilder{bufferBuilder: bufferBuilder{refCount: 1, mem: mem}} + b := &int8BufferBuilder{bufferBuilder: bufferBuilder{mem: mem}} + b.refCount.Add(1) + return b } // AppendValues appends the contents of v to the buffer, growing the buffer as needed. diff --git a/arrow/array/bufferbuilder_numeric.gen.go.tmpl b/arrow/array/bufferbuilder_numeric.gen.go.tmpl index c3c39de1..35820575 100644 --- a/arrow/array/bufferbuilder_numeric.gen.go.tmpl +++ b/arrow/array/bufferbuilder_numeric.gen.go.tmpl @@ -30,7 +30,9 @@ type {{$TypeNamePrefix}}BufferBuilder struct { } func new{{.Name}}BufferBuilder(mem memory.Allocator) *{{$TypeNamePrefix}}BufferBuilder { - return &{{$TypeNamePrefix}}BufferBuilder{bufferBuilder:bufferBuilder{refCount: 1, mem:mem}} + b := &{{$TypeNamePrefix}}BufferBuilder{bufferBuilder:bufferBuilder{mem:mem}} + b.refCount.Add(1) + return b } // AppendValues appends the contents of v to the buffer, growing the buffer as needed. diff --git a/arrow/array/builder.go b/arrow/array/builder.go index a2a40d48..0b3a4e9a 100644 --- a/arrow/array/builder.go +++ b/arrow/array/builder.go @@ -102,7 +102,7 @@ type Builder interface { // builder provides common functionality for managing the validity bitmap (nulls) when building arrays. type builder struct { - refCount int64 + refCount atomic.Int64 mem memory.Allocator nullBitmap *memory.Buffer nulls int @@ -113,7 +113,7 @@ type builder struct { // Retain increases the reference count by 1. // Retain may be called simultaneously from multiple goroutines. 
func (b *builder) Retain() { - atomic.AddInt64(&b.refCount, 1) + b.refCount.Add(1) } // Len returns the number of elements in the array builder. @@ -176,13 +176,13 @@ func (b *builder) resize(newBits int, init func(int)) { } func (b *builder) reserve(elements int, resize func(int)) { - if b.nullBitmap == nil { - b.nullBitmap = memory.NewResizableBuffer(b.mem) - } if b.length+elements > b.capacity { newCap := bitutil.NextPowerOf2(b.length + elements) resize(newCap) } + if b.nullBitmap == nil { + b.nullBitmap = memory.NewResizableBuffer(b.mem) + } } // unsafeAppendBoolsToBitmap appends the contents of valid to the validity bitmap. diff --git a/arrow/array/compare.go b/arrow/array/compare.go index e412febf..fda15f50 100644 --- a/arrow/array/compare.go +++ b/arrow/array/compare.go @@ -240,37 +240,37 @@ func Equal(left, right arrow.Array) bool { return arrayEqualStringView(l, r) case *Int8: r := right.(*Int8) - return arrayEqualInt8(l, r) + return arrayEqualFixedWidth(l, r) case *Int16: r := right.(*Int16) - return arrayEqualInt16(l, r) + return arrayEqualFixedWidth(l, r) case *Int32: r := right.(*Int32) - return arrayEqualInt32(l, r) + return arrayEqualFixedWidth(l, r) case *Int64: r := right.(*Int64) - return arrayEqualInt64(l, r) + return arrayEqualFixedWidth(l, r) case *Uint8: r := right.(*Uint8) - return arrayEqualUint8(l, r) + return arrayEqualFixedWidth(l, r) case *Uint16: r := right.(*Uint16) - return arrayEqualUint16(l, r) + return arrayEqualFixedWidth(l, r) case *Uint32: r := right.(*Uint32) - return arrayEqualUint32(l, r) + return arrayEqualFixedWidth(l, r) case *Uint64: r := right.(*Uint64) - return arrayEqualUint64(l, r) + return arrayEqualFixedWidth(l, r) case *Float16: r := right.(*Float16) - return arrayEqualFloat16(l, r) + return arrayEqualFixedWidth(l, r) case *Float32: r := right.(*Float32) - return arrayEqualFloat32(l, r) + return arrayEqualFixedWidth(l, r) case *Float64: r := right.(*Float64) - return arrayEqualFloat64(l, r) + return arrayEqualFixedWidth(l, r) case *Decimal32: r := right.(*Decimal32) return arrayEqualDecimal(l, r) @@ -285,16 +285,16 @@ func Equal(left, right arrow.Array) bool { return arrayEqualDecimal(l, r) case *Date32: r := right.(*Date32) - return arrayEqualDate32(l, r) + return arrayEqualFixedWidth(l, r) case *Date64: r := right.(*Date64) - return arrayEqualDate64(l, r) + return arrayEqualFixedWidth(l, r) case *Time32: r := right.(*Time32) - return arrayEqualTime32(l, r) + return arrayEqualFixedWidth(l, r) case *Time64: r := right.(*Time64) - return arrayEqualTime64(l, r) + return arrayEqualFixedWidth(l, r) case *Timestamp: r := right.(*Timestamp) return arrayEqualTimestamp(l, r) @@ -327,7 +327,7 @@ func Equal(left, right arrow.Array) bool { return arrayEqualMonthDayNanoInterval(l, r) case *Duration: r := right.(*Duration) - return arrayEqualDuration(l, r) + return arrayEqualFixedWidth(l, r) case *Map: r := right.(*Map) return arrayEqualMap(l, r) @@ -502,28 +502,28 @@ func arrayApproxEqual(left, right arrow.Array, opt equalOption) bool { return arrayApproxEqualStringView(l, r) case *Int8: r := right.(*Int8) - return arrayEqualInt8(l, r) + return arrayEqualFixedWidth(l, r) case *Int16: r := right.(*Int16) - return arrayEqualInt16(l, r) + return arrayEqualFixedWidth(l, r) case *Int32: r := right.(*Int32) - return arrayEqualInt32(l, r) + return arrayEqualFixedWidth(l, r) case *Int64: r := right.(*Int64) - return arrayEqualInt64(l, r) + return arrayEqualFixedWidth(l, r) case *Uint8: r := right.(*Uint8) - return arrayEqualUint8(l, r) + return 
arrayEqualFixedWidth(l, r) case *Uint16: r := right.(*Uint16) - return arrayEqualUint16(l, r) + return arrayEqualFixedWidth(l, r) case *Uint32: r := right.(*Uint32) - return arrayEqualUint32(l, r) + return arrayEqualFixedWidth(l, r) case *Uint64: r := right.(*Uint64) - return arrayEqualUint64(l, r) + return arrayEqualFixedWidth(l, r) case *Float16: r := right.(*Float16) return arrayApproxEqualFloat16(l, r, opt) @@ -547,16 +547,16 @@ func arrayApproxEqual(left, right arrow.Array, opt equalOption) bool { return arrayEqualDecimal(l, r) case *Date32: r := right.(*Date32) - return arrayEqualDate32(l, r) + return arrayEqualFixedWidth(l, r) case *Date64: r := right.(*Date64) - return arrayEqualDate64(l, r) + return arrayEqualFixedWidth(l, r) case *Time32: r := right.(*Time32) - return arrayEqualTime32(l, r) + return arrayEqualFixedWidth(l, r) case *Time64: r := right.(*Time64) - return arrayEqualTime64(l, r) + return arrayEqualFixedWidth(l, r) case *Timestamp: r := right.(*Timestamp) return arrayEqualTimestamp(l, r) @@ -589,7 +589,7 @@ func arrayApproxEqual(left, right arrow.Array, opt equalOption) bool { return arrayEqualMonthDayNanoInterval(l, r) case *Duration: r := right.(*Duration) - return arrayEqualDuration(l, r) + return arrayEqualFixedWidth(l, r) case *Map: r := right.(*Map) if opt.unorderedMapKeys { diff --git a/arrow/array/concat.go b/arrow/array/concat.go index bb50354b..8f6aefbe 100644 --- a/arrow/array/concat.go +++ b/arrow/array/concat.go @@ -517,7 +517,9 @@ func concatListView(data []arrow.ArrayData, offsetType arrow.FixedWidthDataType, // concat is the implementation for actually performing the concatenation of the arrow.ArrayData // objects that we can call internally for nested types. func concat(data []arrow.ArrayData, mem memory.Allocator) (arr arrow.ArrayData, err error) { - out := &Data{refCount: 1, dtype: data[0].DataType(), nulls: 0} + out := &Data{dtype: data[0].DataType(), nulls: 0} + out.refCount.Add(1) + defer func() { if pErr := recover(); pErr != nil { err = utils.FormatRecoveredError("arrow/concat", pErr) diff --git a/arrow/array/data.go b/arrow/array/data.go index be75c7c7..62284b39 100644 --- a/arrow/array/data.go +++ b/arrow/array/data.go @@ -29,7 +29,7 @@ import ( // Data represents the memory and metadata of an Arrow array. type Data struct { - refCount int64 + refCount atomic.Int64 dtype arrow.DataType nulls int offset int @@ -56,8 +56,7 @@ func NewData(dtype arrow.DataType, length int, buffers []*memory.Buffer, childDa } } - return &Data{ - refCount: 1, + d := &Data{ dtype: dtype, nulls: nulls, length: length, @@ -65,6 +64,8 @@ func NewData(dtype arrow.DataType, length int, buffers []*memory.Buffer, childDa buffers: buffers, childData: childData, } + d.refCount.Add(1) + return d } // NewDataWithDictionary creates a new data object, but also sets the provided dictionary into the data if it's not nil @@ -129,16 +130,16 @@ func (d *Data) Reset(dtype arrow.DataType, length int, buffers []*memory.Buffer, // Retain increases the reference count by 1. // Retain may be called simultaneously from multiple goroutines. func (d *Data) Retain() { - atomic.AddInt64(&d.refCount, 1) + d.refCount.Add(1) } // Release decreases the reference count by 1. // When the reference count goes to zero, the memory is freed. // Release may be called simultaneously from multiple goroutines. 
func (d *Data) Release() { - debug.Assert(atomic.LoadInt64(&d.refCount) > 0, "too many releases") + debug.Assert(d.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&d.refCount, -1) == 0 { + if d.refCount.Add(-1) == 0 { for _, b := range d.buffers { if b != nil { b.Release() @@ -246,7 +247,6 @@ func NewSliceData(data arrow.ArrayData, i, j int64) arrow.ArrayData { } o := &Data{ - refCount: 1, dtype: data.DataType(), nulls: UnknownNullCount, length: int(j - i), @@ -255,6 +255,7 @@ func NewSliceData(data arrow.ArrayData, i, j int64) arrow.ArrayData { childData: data.Children(), dictionary: data.(*Data).dictionary, } + o.refCount.Add(1) if data.NullN() == 0 { o.nulls = 0 diff --git a/arrow/array/decimal.go b/arrow/array/decimal.go index 1a9d61c1..dff0feaf 100644 --- a/arrow/array/decimal.go +++ b/arrow/array/decimal.go @@ -21,7 +21,6 @@ import ( "fmt" "reflect" "strings" - "sync/atomic" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/bitutil" @@ -45,7 +44,7 @@ func newDecimalData[T interface { decimal.Num[T] }](data arrow.ArrayData) *baseDecimal[T] { a := &baseDecimal[T]{} - a.refCount = 1 + a.refCount.Add(1) a.setData(data.(*Data)) return a } @@ -148,11 +147,13 @@ func NewDecimal256Data(data arrow.ArrayData) *Decimal256 { return newDecimalData[decimal.Decimal256](data) } -type Decimal32Builder = baseDecimalBuilder[decimal.Decimal32] -type Decimal64Builder = baseDecimalBuilder[decimal.Decimal64] -type Decimal128Builder struct { - *baseDecimalBuilder[decimal.Decimal128] -} +type ( + Decimal32Builder = baseDecimalBuilder[decimal.Decimal32] + Decimal64Builder = baseDecimalBuilder[decimal.Decimal64] + Decimal128Builder struct { + *baseDecimalBuilder[decimal.Decimal128] + } +) func (b *Decimal128Builder) NewDecimal128Array() *Decimal128 { return b.NewDecimalArray() @@ -182,18 +183,20 @@ func newDecimalBuilder[T interface { decimal.DecimalTypes decimal.Num[T] }, DT arrow.DecimalType](mem memory.Allocator, dtype DT) *baseDecimalBuilder[T] { - return &baseDecimalBuilder[T]{ - builder: builder{refCount: 1, mem: mem}, + bdb := &baseDecimalBuilder[T]{ + builder: builder{mem: mem}, dtype: dtype, } + bdb.builder.refCount.Add(1) + return bdb } func (b *baseDecimalBuilder[T]) Type() arrow.DataType { return b.dtype } func (b *baseDecimalBuilder[T]) Release() { - debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases") + debug.Assert(b.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&b.refCount, -1) == 0 { + if b.refCount.Add(-1) == 0 { if b.nullBitmap != nil { b.nullBitmap.Release() b.nullBitmap = nil @@ -429,4 +432,9 @@ var ( _ Builder = (*Decimal64Builder)(nil) _ Builder = (*Decimal128Builder)(nil) _ Builder = (*Decimal256Builder)(nil) + + _ arrow.TypedArray[decimal.Decimal32] = (*Decimal32)(nil) + _ arrow.TypedArray[decimal.Decimal64] = (*Decimal64)(nil) + _ arrow.TypedArray[decimal.Decimal128] = (*Decimal128)(nil) + _ arrow.TypedArray[decimal.Decimal256] = (*Decimal256)(nil) ) diff --git a/arrow/array/dictionary.go b/arrow/array/dictionary.go index 0c23934a..4ddb5d4c 100644 --- a/arrow/array/dictionary.go +++ b/arrow/array/dictionary.go @@ -22,14 +22,11 @@ import ( "fmt" "math" "math/bits" - "sync/atomic" "unsafe" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/bitutil" "github.com/apache/arrow-go/v18/arrow/decimal" - "github.com/apache/arrow-go/v18/arrow/decimal128" - "github.com/apache/arrow-go/v18/arrow/decimal256" "github.com/apache/arrow-go/v18/arrow/float16" 
"github.com/apache/arrow-go/v18/arrow/internal/debug" "github.com/apache/arrow-go/v18/arrow/memory" @@ -66,7 +63,7 @@ type Dictionary struct { // and dictionary using the given type. func NewDictionaryArray(typ arrow.DataType, indices, dict arrow.Array) *Dictionary { a := &Dictionary{} - a.array.refCount = 1 + a.array.refCount.Add(1) dictdata := NewData(typ, indices.Len(), indices.Data().Buffers(), indices.Data().Children(), indices.NullN(), indices.Data().Offset()) dictdata.dictionary = dict.Data().(*Data) dict.Data().Retain() @@ -188,19 +185,19 @@ func NewValidatedDictionaryArray(typ *arrow.DictionaryType, indices, dict arrow. // an ArrayData object with a datatype of arrow.Dictionary and a dictionary func NewDictionaryData(data arrow.ArrayData) *Dictionary { a := &Dictionary{} - a.refCount = 1 + a.refCount.Add(1) a.setData(data.(*Data)) return a } func (d *Dictionary) Retain() { - atomic.AddInt64(&d.refCount, 1) + d.refCount.Add(1) } func (d *Dictionary) Release() { - debug.Assert(atomic.LoadInt64(&d.refCount) > 0, "too many releases") + debug.Assert(d.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&d.refCount, -1) == 0 { + if d.refCount.Add(-1) == 0 { d.data.Release() d.data, d.nullBitmapBytes = nil, nil d.indices.Release() @@ -426,6 +423,73 @@ type dictionaryBuilder struct { idxBuilder IndexBuilder } +func createDictBuilder[T arrow.ValueType](mem memory.Allocator, idxbldr IndexBuilder, memo hashing.MemoTable, dt *arrow.DictionaryType, init arrow.Array) DictionaryBuilder { + ret := &dictBuilder[T]{ + dictionaryBuilder: dictionaryBuilder{ + builder: builder{mem: mem}, + idxBuilder: idxbldr, + memoTable: memo, + dt: dt, + }, + } + ret.builder.refCount.Add(1) + + if init != nil { + if err := ret.InsertDictValues(init.(arrValues[T])); err != nil { + panic(err) + } + } + return ret +} + +func createBinaryDictBuilder(mem memory.Allocator, idxbldr IndexBuilder, memo hashing.MemoTable, dt *arrow.DictionaryType, init arrow.Array) DictionaryBuilder { + ret := &BinaryDictionaryBuilder{ + dictionaryBuilder: dictionaryBuilder{ + builder: builder{mem: mem}, + idxBuilder: idxbldr, + memoTable: memo, + dt: dt, + }, + } + ret.builder.refCount.Add(1) + + if init != nil { + switch v := init.(type) { + case *String: + if err := ret.InsertStringDictValues(v); err != nil { + panic(err) + } + case *Binary: + if err := ret.InsertDictValues(v); err != nil { + panic(err) + } + } + } + return ret +} + +func createFixedSizeDictBuilder[T fsbType](mem memory.Allocator, idxbldr IndexBuilder, memo hashing.MemoTable, dt *arrow.DictionaryType, init arrow.Array) DictionaryBuilder { + var z T + ret := &fixedSizeDictionaryBuilder[T]{ + dictionaryBuilder: dictionaryBuilder{ + builder: builder{mem: mem}, + idxBuilder: idxbldr, + memoTable: memo, + dt: dt, + }, + byteWidth: int(unsafe.Sizeof(z)), + } + ret.builder.refCount.Add(1) + + if init != nil { + if err := ret.InsertDictValues(init.(arrValues[T])); err != nil { + panic(err) + } + } + + return ret +} + // NewDictionaryBuilderWithDict initializes a dictionary builder and inserts the values from `init` as the first // values in the dictionary, but does not insert them as values into the array. 
func NewDictionaryBuilderWithDict(mem memory.Allocator, dt *arrow.DictionaryType, init arrow.Array) DictionaryBuilder { @@ -443,126 +507,55 @@ func NewDictionaryBuilderWithDict(mem memory.Allocator, dt *arrow.DictionaryType panic(fmt.Errorf("arrow/array: unsupported builder for value type of %T", dt)) } - bldr := dictionaryBuilder{ - builder: builder{refCount: 1, mem: mem}, - idxBuilder: idxbldr, - memoTable: memo, - dt: dt, - } - switch dt.ValueType.ID() { case arrow.NULL: - ret := &NullDictionaryBuilder{bldr} + ret := &NullDictionaryBuilder{ + dictionaryBuilder: dictionaryBuilder{ + builder: builder{mem: mem}, + idxBuilder: idxbldr, + memoTable: memo, + dt: dt, + }, + } + ret.builder.refCount.Add(1) debug.Assert(init == nil, "arrow/array: doesn't make sense to init a null dictionary") return ret case arrow.UINT8: - ret := &Uint8DictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*Uint8)); err != nil { - panic(err) - } - } - return ret + return createDictBuilder[uint8](mem, idxbldr, memo, dt, init) case arrow.INT8: - ret := &Int8DictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*Int8)); err != nil { - panic(err) - } - } - return ret + return createDictBuilder[int8](mem, idxbldr, memo, dt, init) case arrow.UINT16: - ret := &Uint16DictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*Uint16)); err != nil { - panic(err) - } - } - return ret + return createDictBuilder[uint16](mem, idxbldr, memo, dt, init) case arrow.INT16: - ret := &Int16DictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*Int16)); err != nil { - panic(err) - } - } - return ret + return createDictBuilder[int16](mem, idxbldr, memo, dt, init) case arrow.UINT32: - ret := &Uint32DictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*Uint32)); err != nil { - panic(err) - } - } - return ret + return createDictBuilder[uint32](mem, idxbldr, memo, dt, init) case arrow.INT32: - ret := &Int32DictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*Int32)); err != nil { - panic(err) - } - } - return ret + return createDictBuilder[int32](mem, idxbldr, memo, dt, init) case arrow.UINT64: - ret := &Uint64DictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*Uint64)); err != nil { - panic(err) - } - } - return ret + return createDictBuilder[uint64](mem, idxbldr, memo, dt, init) case arrow.INT64: - ret := &Int64DictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*Int64)); err != nil { - panic(err) - } - } - return ret + return createDictBuilder[int64](mem, idxbldr, memo, dt, init) case arrow.FLOAT16: - ret := &Float16DictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*Float16)); err != nil { - panic(err) - } - } - return ret + return createDictBuilder[float16.Num](mem, idxbldr, memo, dt, init) case arrow.FLOAT32: - ret := &Float32DictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*Float32)); err != nil { - panic(err) - } - } - return ret + return createDictBuilder[float32](mem, idxbldr, memo, dt, init) case arrow.FLOAT64: - ret := &Float64DictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*Float64)); err != nil { - panic(err) - } - } - return ret - case arrow.STRING: - ret := &BinaryDictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertStringDictValues(init.(*String)); err != nil { - panic(err) - } - } - return ret - case 
arrow.BINARY: - ret := &BinaryDictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*Binary)); err != nil { - panic(err) - } - } - return ret + return createDictBuilder[float64](mem, idxbldr, memo, dt, init) + case arrow.STRING, arrow.BINARY: + return createBinaryDictBuilder(mem, idxbldr, memo, dt, init) case arrow.FIXED_SIZE_BINARY: ret := &FixedSizeBinaryDictionaryBuilder{ - bldr, dt.ValueType.(*arrow.FixedSizeBinaryType).ByteWidth, + dictionaryBuilder: dictionaryBuilder{ + builder: builder{mem: mem}, + idxBuilder: idxbldr, + memoTable: memo, + dt: dt, + }, + byteWidth: dt.ValueType.(*arrow.FixedSizeBinaryType).ByteWidth, } + ret.builder.refCount.Add(1) + if init != nil { if err = ret.InsertDictValues(init.(*FixedSizeBinary)); err != nil { panic(err) @@ -570,93 +563,27 @@ func NewDictionaryBuilderWithDict(mem memory.Allocator, dt *arrow.DictionaryType } return ret case arrow.DATE32: - ret := &Date32DictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*Date32)); err != nil { - panic(err) - } - } - return ret + return createDictBuilder[arrow.Date32](mem, idxbldr, memo, dt, init) case arrow.DATE64: - ret := &Date64DictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*Date64)); err != nil { - panic(err) - } - } - return ret + return createDictBuilder[arrow.Date64](mem, idxbldr, memo, dt, init) case arrow.TIMESTAMP: - ret := &TimestampDictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*Timestamp)); err != nil { - panic(err) - } - } - return ret + return createDictBuilder[arrow.Timestamp](mem, idxbldr, memo, dt, init) case arrow.TIME32: - ret := &Time32DictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*Time32)); err != nil { - panic(err) - } - } - return ret + return createDictBuilder[arrow.Time32](mem, idxbldr, memo, dt, init) case arrow.TIME64: - ret := &Time64DictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*Time64)); err != nil { - panic(err) - } - } - return ret + return createDictBuilder[arrow.Time64](mem, idxbldr, memo, dt, init) case arrow.INTERVAL_MONTHS: - ret := &MonthIntervalDictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*MonthInterval)); err != nil { - panic(err) - } - } - return ret + return createDictBuilder[arrow.MonthInterval](mem, idxbldr, memo, dt, init) case arrow.INTERVAL_DAY_TIME: - ret := &DayTimeDictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*DayTimeInterval)); err != nil { - panic(err) - } - } - return ret + return createFixedSizeDictBuilder[arrow.DayTimeInterval](mem, idxbldr, memo, dt, init) case arrow.DECIMAL32: - ret := &Decimal32DictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*Decimal32)); err != nil { - panic(err) - } - } - return ret + return createFixedSizeDictBuilder[decimal.Decimal32](mem, idxbldr, memo, dt, init) case arrow.DECIMAL64: - ret := &Decimal64DictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*Decimal64)); err != nil { - panic(err) - } - } - return ret + return createFixedSizeDictBuilder[decimal.Decimal64](mem, idxbldr, memo, dt, init) case arrow.DECIMAL128: - ret := &Decimal128DictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*Decimal128)); err != nil { - panic(err) - } - } - return ret + return createFixedSizeDictBuilder[decimal.Decimal128](mem, idxbldr, memo, dt, init) case arrow.DECIMAL256: - ret := 
&Decimal256DictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*Decimal256)); err != nil { - panic(err) - } - } - return ret + return createFixedSizeDictBuilder[decimal.Decimal256](mem, idxbldr, memo, dt, init) case arrow.LIST: case arrow.STRUCT: case arrow.SPARSE_UNION: @@ -666,24 +593,12 @@ func NewDictionaryBuilderWithDict(mem memory.Allocator, dt *arrow.DictionaryType case arrow.EXTENSION: case arrow.FIXED_SIZE_LIST: case arrow.DURATION: - ret := &DurationDictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*Duration)); err != nil { - panic(err) - } - } - return ret + return createDictBuilder[arrow.Duration](mem, idxbldr, memo, dt, init) case arrow.LARGE_STRING: case arrow.LARGE_BINARY: case arrow.LARGE_LIST: case arrow.INTERVAL_MONTH_DAY_NANO: - ret := &MonthDayNanoDictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*MonthDayNanoInterval)); err != nil { - panic(err) - } - } - return ret + return createFixedSizeDictBuilder[arrow.MonthDayNanoInterval](mem, idxbldr, memo, dt, init) } panic("arrow/array: unimplemented dictionary key type") @@ -696,9 +611,9 @@ func NewDictionaryBuilder(mem memory.Allocator, dt *arrow.DictionaryType) Dictio func (b *dictionaryBuilder) Type() arrow.DataType { return b.dt } func (b *dictionaryBuilder) Release() { - debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases") + debug.Assert(b.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&b.refCount, -1) == 0 { + if b.refCount.Add(-1) == 0 { b.idxBuilder.Release() b.idxBuilder.Builder = nil if binmemo, ok := b.memoTable.(*hashing.BinaryMemoTable); ok { @@ -820,7 +735,7 @@ func (b *dictionaryBuilder) newData() *Data { func (b *dictionaryBuilder) NewDictionaryArray() *Dictionary { a := &Dictionary{} - a.refCount = 1 + a.refCount.Add(1) indices := b.newData() a.setData(indices) @@ -1071,27 +986,20 @@ func (b *NullDictionaryBuilder) AppendArray(arr arrow.Array) error { return nil } -type Int8DictionaryBuilder struct { +type dictBuilder[T arrow.ValueType] struct { dictionaryBuilder } -func (b *Int8DictionaryBuilder) Append(v int8) error { return b.appendValue(v) } -func (b *Int8DictionaryBuilder) InsertDictValues(arr *Int8) (err error) { - for _, v := range arr.values { - if err = b.insertDictValue(v); err != nil { - break - } - } - return +func (b *dictBuilder[T]) Append(v T) error { + return b.appendValue(v) } -type Uint8DictionaryBuilder struct { - dictionaryBuilder +type arrValues[T arrow.ValueType] interface { + Values() []T } -func (b *Uint8DictionaryBuilder) Append(v uint8) error { return b.appendValue(v) } -func (b *Uint8DictionaryBuilder) InsertDictValues(arr *Uint8) (err error) { - for _, v := range arr.values { +func (b *dictBuilder[T]) InsertDictValues(arr arrValues[T]) (err error) { + for _, v := range arr.Values() { if err = b.insertDictValue(v); err != nil { break } @@ -1099,231 +1007,30 @@ func (b *Uint8DictionaryBuilder) InsertDictValues(arr *Uint8) (err error) { return } -type Int16DictionaryBuilder struct { - dictionaryBuilder -} - -func (b *Int16DictionaryBuilder) Append(v int16) error { return b.appendValue(v) } -func (b *Int16DictionaryBuilder) InsertDictValues(arr *Int16) (err error) { - for _, v := range arr.values { - if err = b.insertDictValue(v); err != nil { - break - } - } - return -} - -type Uint16DictionaryBuilder struct { - dictionaryBuilder -} - -func (b *Uint16DictionaryBuilder) Append(v uint16) error { return b.appendValue(v) } -func (b *Uint16DictionaryBuilder) 
InsertDictValues(arr *Uint16) (err error) { - for _, v := range arr.values { - if err = b.insertDictValue(v); err != nil { - break - } - } - return -} - -type Int32DictionaryBuilder struct { - dictionaryBuilder -} - -func (b *Int32DictionaryBuilder) Append(v int32) error { return b.appendValue(v) } -func (b *Int32DictionaryBuilder) InsertDictValues(arr *Int32) (err error) { - for _, v := range arr.values { - if err = b.insertDictValue(v); err != nil { - break - } - } - return -} - -type Uint32DictionaryBuilder struct { - dictionaryBuilder -} - -func (b *Uint32DictionaryBuilder) Append(v uint32) error { return b.appendValue(v) } -func (b *Uint32DictionaryBuilder) InsertDictValues(arr *Uint32) (err error) { - for _, v := range arr.values { - if err = b.insertDictValue(v); err != nil { - break - } - } - return -} - -type Int64DictionaryBuilder struct { - dictionaryBuilder -} - -func (b *Int64DictionaryBuilder) Append(v int64) error { return b.appendValue(v) } -func (b *Int64DictionaryBuilder) InsertDictValues(arr *Int64) (err error) { - for _, v := range arr.values { - if err = b.insertDictValue(v); err != nil { - break - } - } - return -} - -type Uint64DictionaryBuilder struct { - dictionaryBuilder -} - -func (b *Uint64DictionaryBuilder) Append(v uint64) error { return b.appendValue(v) } -func (b *Uint64DictionaryBuilder) InsertDictValues(arr *Uint64) (err error) { - for _, v := range arr.values { - if err = b.insertDictValue(v); err != nil { - break - } - } - return -} - -type DurationDictionaryBuilder struct { - dictionaryBuilder -} - -func (b *DurationDictionaryBuilder) Append(v arrow.Duration) error { return b.appendValue(int64(v)) } -func (b *DurationDictionaryBuilder) InsertDictValues(arr *Duration) (err error) { - for _, v := range arr.values { - if err = b.insertDictValue(int64(v)); err != nil { - break - } - } - return -} - -type TimestampDictionaryBuilder struct { - dictionaryBuilder -} - -func (b *TimestampDictionaryBuilder) Append(v arrow.Timestamp) error { return b.appendValue(int64(v)) } -func (b *TimestampDictionaryBuilder) InsertDictValues(arr *Timestamp) (err error) { - for _, v := range arr.values { - if err = b.insertDictValue(int64(v)); err != nil { - break - } - } - return -} - -type Time32DictionaryBuilder struct { - dictionaryBuilder -} - -func (b *Time32DictionaryBuilder) Append(v arrow.Time32) error { return b.appendValue(int32(v)) } -func (b *Time32DictionaryBuilder) InsertDictValues(arr *Time32) (err error) { - for _, v := range arr.values { - if err = b.insertDictValue(int32(v)); err != nil { - break - } - } - return -} - -type Time64DictionaryBuilder struct { - dictionaryBuilder -} - -func (b *Time64DictionaryBuilder) Append(v arrow.Time64) error { return b.appendValue(int64(v)) } -func (b *Time64DictionaryBuilder) InsertDictValues(arr *Time64) (err error) { - for _, v := range arr.values { - if err = b.insertDictValue(int64(v)); err != nil { - break - } - } - return -} - -type Date32DictionaryBuilder struct { - dictionaryBuilder -} - -func (b *Date32DictionaryBuilder) Append(v arrow.Date32) error { return b.appendValue(int32(v)) } -func (b *Date32DictionaryBuilder) InsertDictValues(arr *Date32) (err error) { - for _, v := range arr.values { - if err = b.insertDictValue(int32(v)); err != nil { - break - } - } - return -} - -type Date64DictionaryBuilder struct { - dictionaryBuilder -} - -func (b *Date64DictionaryBuilder) Append(v arrow.Date64) error { return b.appendValue(int64(v)) } -func (b *Date64DictionaryBuilder) InsertDictValues(arr *Date64) (err error) { - 
for _, v := range arr.values { - if err = b.insertDictValue(int64(v)); err != nil { - break - } - } - return -} - -type MonthIntervalDictionaryBuilder struct { - dictionaryBuilder -} - -func (b *MonthIntervalDictionaryBuilder) Append(v arrow.MonthInterval) error { - return b.appendValue(int32(v)) -} -func (b *MonthIntervalDictionaryBuilder) InsertDictValues(arr *MonthInterval) (err error) { - for _, v := range arr.values { - if err = b.insertDictValue(int32(v)); err != nil { - break - } - } - return -} - -type Float16DictionaryBuilder struct { - dictionaryBuilder -} - -func (b *Float16DictionaryBuilder) Append(v float16.Num) error { return b.appendValue(v.Uint16()) } -func (b *Float16DictionaryBuilder) InsertDictValues(arr *Float16) (err error) { - for _, v := range arr.values { - if err = b.insertDictValue(v.Uint16()); err != nil { - break - } - } - return -} - -type Float32DictionaryBuilder struct { - dictionaryBuilder -} - -func (b *Float32DictionaryBuilder) Append(v float32) error { return b.appendValue(v) } -func (b *Float32DictionaryBuilder) InsertDictValues(arr *Float32) (err error) { - for _, v := range arr.values { - if err = b.insertDictValue(v); err != nil { - break - } - } - return -} - -type Float64DictionaryBuilder struct { - dictionaryBuilder -} - -func (b *Float64DictionaryBuilder) Append(v float64) error { return b.appendValue(v) } -func (b *Float64DictionaryBuilder) InsertDictValues(arr *Float64) (err error) { - for _, v := range arr.values { - if err = b.insertDictValue(v); err != nil { - break - } - } - return -} +type Int8DictionaryBuilder = dictBuilder[int8] +type Uint8DictionaryBuilder = dictBuilder[uint8] +type Int16DictionaryBuilder = dictBuilder[int16] +type Uint16DictionaryBuilder = dictBuilder[uint16] +type Int32DictionaryBuilder = dictBuilder[int32] +type Uint32DictionaryBuilder = dictBuilder[uint32] +type Int64DictionaryBuilder = dictBuilder[int64] +type Uint64DictionaryBuilder = dictBuilder[uint64] +type Float16DictionaryBuilder = dictBuilder[float16.Num] +type Float32DictionaryBuilder = dictBuilder[float32] +type Float64DictionaryBuilder = dictBuilder[float64] +type DurationDictionaryBuilder = dictBuilder[arrow.Duration] +type TimestampDictionaryBuilder = dictBuilder[arrow.Timestamp] +type Time32DictionaryBuilder = dictBuilder[arrow.Time32] +type Time64DictionaryBuilder = dictBuilder[arrow.Time64] +type Date32DictionaryBuilder = dictBuilder[arrow.Date32] +type Date64DictionaryBuilder = dictBuilder[arrow.Date64] +type MonthIntervalDictionaryBuilder = dictBuilder[arrow.MonthInterval] +type DayTimeDictionaryBuilder = fixedSizeDictionaryBuilder[arrow.DayTimeInterval] +type Decimal32DictionaryBuilder = fixedSizeDictionaryBuilder[decimal.Decimal32] +type Decimal64DictionaryBuilder = fixedSizeDictionaryBuilder[decimal.Decimal64] +type Decimal128DictionaryBuilder = fixedSizeDictionaryBuilder[decimal.Decimal128] +type Decimal256DictionaryBuilder = fixedSizeDictionaryBuilder[decimal.Decimal256] +type MonthDayNanoDictionaryBuilder = fixedSizeDictionaryBuilder[arrow.MonthDayNanoInterval] type BinaryDictionaryBuilder struct { dictionaryBuilder @@ -1351,6 +1058,7 @@ func (b *BinaryDictionaryBuilder) InsertDictValues(arr *Binary) (err error) { } return } + func (b *BinaryDictionaryBuilder) InsertStringDictValues(arr *String) (err error) { if !arrow.TypeEqual(arr.DataType(), b.dt.ValueType) { return fmt.Errorf("dictionary insert type mismatch: cannot insert values of type %T to dictionary type %T", arr.DataType(), b.dt.ValueType) @@ -1399,133 +1107,61 @@ func (b 
*BinaryDictionaryBuilder) ValueStr(i int) string { return string(b.Value(i)) } -type FixedSizeBinaryDictionaryBuilder struct { - dictionaryBuilder - byteWidth int +type fsbType interface { + arrow.DayTimeInterval | arrow.MonthDayNanoInterval | + decimal.Decimal32 | decimal.Decimal64 | decimal.Decimal128 | decimal.Decimal256 } -func (b *FixedSizeBinaryDictionaryBuilder) Append(v []byte) error { - return b.appendValue(v[:b.byteWidth]) -} -func (b *FixedSizeBinaryDictionaryBuilder) InsertDictValues(arr *FixedSizeBinary) (err error) { - var ( - beg = arr.array.data.offset * b.byteWidth - end = (arr.array.data.offset + arr.data.length) * b.byteWidth - ) - data := arr.valueBytes[beg:end] - for len(data) > 0 { - if err = b.insertDictValue(data[:b.byteWidth]); err != nil { - break - } - data = data[b.byteWidth:] - } - return -} - -type Decimal32DictionaryBuilder struct { - dictionaryBuilder -} - -func (b *Decimal32DictionaryBuilder) Append(v decimal.Decimal32) error { - return b.appendValue((*(*[arrow.Decimal32SizeBytes]byte)(unsafe.Pointer(&v)))[:]) -} -func (b *Decimal32DictionaryBuilder) InsertDictValues(arr *Decimal32) (err error) { - data := arrow.Decimal32Traits.CastToBytes(arr.values) - for len(data) > 0 { - if err = b.insertDictValue(data[:arrow.Decimal32SizeBytes]); err != nil { - break - } - data = data[arrow.Decimal32SizeBytes:] - } - return -} - -type Decimal64DictionaryBuilder struct { - dictionaryBuilder -} - -func (b *Decimal64DictionaryBuilder) Append(v decimal.Decimal64) error { - return b.appendValue((*(*[arrow.Decimal64SizeBytes]byte)(unsafe.Pointer(&v)))[:]) -} -func (b *Decimal64DictionaryBuilder) InsertDictValues(arr *Decimal64) (err error) { - data := arrow.Decimal64Traits.CastToBytes(arr.values) - for len(data) > 0 { - if err = b.insertDictValue(data[:arrow.Decimal64SizeBytes]); err != nil { - break - } - data = data[arrow.Decimal64SizeBytes:] - } - return -} - -type Decimal128DictionaryBuilder struct { +type fixedSizeDictionaryBuilder[T fsbType] struct { dictionaryBuilder + byteWidth int } -func (b *Decimal128DictionaryBuilder) Append(v decimal128.Num) error { - return b.appendValue((*(*[arrow.Decimal128SizeBytes]byte)(unsafe.Pointer(&v)))[:]) -} -func (b *Decimal128DictionaryBuilder) InsertDictValues(arr *Decimal128) (err error) { - data := arrow.Decimal128Traits.CastToBytes(arr.values) - for len(data) > 0 { - if err = b.insertDictValue(data[:arrow.Decimal128SizeBytes]); err != nil { - break - } - data = data[arrow.Decimal128SizeBytes:] +func (b *fixedSizeDictionaryBuilder[T]) Append(v T) error { + if v, ok := any(v).([]byte); ok { + return b.appendBytes(v[:b.byteWidth]) } - return -} -type Decimal256DictionaryBuilder struct { - dictionaryBuilder + sliceHdr := struct { + Addr *T + Len int + Cap int + }{&v, b.byteWidth, b.byteWidth} + slice := *(*[]byte)(unsafe.Pointer(&sliceHdr)) + return b.appendValue(slice) } -func (b *Decimal256DictionaryBuilder) Append(v decimal256.Num) error { - return b.appendValue((*(*[arrow.Decimal256SizeBytes]byte)(unsafe.Pointer(&v)))[:]) -} -func (b *Decimal256DictionaryBuilder) InsertDictValues(arr *Decimal256) (err error) { - data := arrow.Decimal256Traits.CastToBytes(arr.values) +func (b *fixedSizeDictionaryBuilder[T]) InsertDictValues(arr arrValues[T]) (err error) { + data := arrow.GetBytes(arr.Values()) for len(data) > 0 { - if err = b.insertDictValue(data[:arrow.Decimal256SizeBytes]); err != nil { + if err = b.insertDictBytes(data[:b.byteWidth]); err != nil { break } - data = data[arrow.Decimal256SizeBytes:] + data = data[b.byteWidth:] } 
return } -type MonthDayNanoDictionaryBuilder struct { +type FixedSizeBinaryDictionaryBuilder struct { dictionaryBuilder + byteWidth int } -func (b *MonthDayNanoDictionaryBuilder) Append(v arrow.MonthDayNanoInterval) error { - return b.appendValue((*(*[arrow.MonthDayNanoIntervalSizeBytes]byte)(unsafe.Pointer(&v)))[:]) -} -func (b *MonthDayNanoDictionaryBuilder) InsertDictValues(arr *MonthDayNanoInterval) (err error) { - data := arrow.MonthDayNanoIntervalTraits.CastToBytes(arr.values) - for len(data) > 0 { - if err = b.insertDictValue(data[:arrow.MonthDayNanoIntervalSizeBytes]); err != nil { - break - } - data = data[arrow.MonthDayNanoIntervalSizeBytes:] - } - return -} - -type DayTimeDictionaryBuilder struct { - dictionaryBuilder +func (b *FixedSizeBinaryDictionaryBuilder) Append(v []byte) error { + return b.appendValue(v[:b.byteWidth]) } -func (b *DayTimeDictionaryBuilder) Append(v arrow.DayTimeInterval) error { - return b.appendValue((*(*[arrow.DayTimeIntervalSizeBytes]byte)(unsafe.Pointer(&v)))[:]) -} -func (b *DayTimeDictionaryBuilder) InsertDictValues(arr *DayTimeInterval) (err error) { - data := arrow.DayTimeIntervalTraits.CastToBytes(arr.values) +func (b *FixedSizeBinaryDictionaryBuilder) InsertDictValues(arr *FixedSizeBinary) (err error) { + var ( + beg = arr.array.data.offset * b.byteWidth + end = (arr.array.data.offset + arr.data.length) * b.byteWidth + ) + data := arr.valueBytes[beg:end] for len(data) > 0 { - if err = b.insertDictValue(data[:arrow.DayTimeIntervalSizeBytes]); err != nil { + if err = b.insertDictValue(data[:b.byteWidth]); err != nil { break } - data = data[arrow.DayTimeIntervalSizeBytes:] + data = data[b.byteWidth:] } return } diff --git a/arrow/array/encoded.go b/arrow/array/encoded.go index 81c375c9..8e39090f 100644 --- a/arrow/array/encoded.go +++ b/arrow/array/encoded.go @@ -21,7 +21,6 @@ import ( "fmt" "math" "reflect" - "sync/atomic" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/encoded" @@ -50,7 +49,7 @@ func NewRunEndEncodedArray(runEnds, values arrow.Array, logicalLength, offset in func NewRunEndEncodedData(data arrow.ArrayData) *RunEndEncoded { r := &RunEndEncoded{} - r.refCount = 1 + r.refCount.Add(1) r.setData(data.(*Data)) return r } @@ -305,14 +304,16 @@ func NewRunEndEncodedBuilder(mem memory.Allocator, runEnds, encoded arrow.DataTy case arrow.INT64: maxEnd = math.MaxInt64 } - return &RunEndEncodedBuilder{ - builder: builder{refCount: 1, mem: mem}, + reb := &RunEndEncodedBuilder{ + builder: builder{mem: mem}, dt: dt, runEnds: NewBuilder(mem, runEnds), values: NewBuilder(mem, encoded), maxRunEnd: maxEnd, lastUnmarshalled: nil, } + reb.builder.refCount.Add(1) + return reb } func (b *RunEndEncodedBuilder) Type() arrow.DataType { @@ -320,9 +321,9 @@ func (b *RunEndEncodedBuilder) Type() arrow.DataType { } func (b *RunEndEncodedBuilder) Release() { - debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases") + debug.Assert(b.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&b.refCount, -1) == 0 { + if b.refCount.Add(-1) == 0 { b.values.Release() b.runEnds.Release() } diff --git a/arrow/array/extension.go b/arrow/array/extension.go index d1a28350..e509b5e0 100644 --- a/arrow/array/extension.go +++ b/arrow/array/extension.go @@ -86,7 +86,7 @@ func NewExtensionArrayWithStorage(dt arrow.ExtensionType, storage arrow.Array) a // underlying data built for the storage array. 
func NewExtensionData(data arrow.ArrayData) ExtensionArray { base := ExtensionArrayBase{} - base.refCount = 1 + base.refCount.Add(1) base.setData(data.(*Data)) // use the ExtensionType's ArrayType to construct the correctly typed object @@ -173,7 +173,7 @@ func (e *ExtensionArrayBase) ValueStr(i int) string { } // no-op function that exists simply to force embedding this in any extension array types. -func (ExtensionArrayBase) mustEmbedExtensionArrayBase() {} +func (*ExtensionArrayBase) mustEmbedExtensionArrayBase() {} // ExtensionBuilder is a convenience builder so that NewBuilder and such will still work // with extension types properly. Depending on preference it may be cleaner or easier to just use diff --git a/arrow/array/fixed_size_list.go b/arrow/array/fixed_size_list.go index 84036f94..4a0524ec 100644 --- a/arrow/array/fixed_size_list.go +++ b/arrow/array/fixed_size_list.go @@ -20,7 +20,6 @@ import ( "bytes" "fmt" "strings" - "sync/atomic" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/bitutil" @@ -41,7 +40,7 @@ var _ ListLike = (*FixedSizeList)(nil) // NewFixedSizeListData returns a new List array value, from data. func NewFixedSizeListData(data arrow.ArrayData) *FixedSizeList { a := &FixedSizeList{} - a.refCount = 1 + a.refCount.Add(1) a.setData(data.(*Data)) return a } @@ -54,6 +53,7 @@ func (a *FixedSizeList) ValueStr(i int) string { } return string(a.GetOneForMarshal(i).(json.RawMessage)) } + func (a *FixedSizeList) String() string { o := new(strings.Builder) o.WriteString("[") @@ -169,28 +169,33 @@ type FixedSizeListBuilder struct { // NewFixedSizeListBuilder returns a builder, using the provided memory allocator. // The created list builder will create a list whose elements will be of type etype. func NewFixedSizeListBuilder(mem memory.Allocator, n int32, etype arrow.DataType) *FixedSizeListBuilder { - return &FixedSizeListBuilder{ + fslb := &FixedSizeListBuilder{ baseListBuilder{ - builder: builder{refCount: 1, mem: mem}, + builder: builder{mem: mem}, values: NewBuilder(mem, etype), dt: arrow.FixedSizeListOf(n, etype), }, n, } + fslb.baseListBuilder.builder.refCount.Add(1) + return fslb } // NewFixedSizeListBuilderWithField returns a builder similarly to // NewFixedSizeListBuilder, but it accepts a child rather than just a datatype // to ensure nullability context is preserved. func NewFixedSizeListBuilderWithField(mem memory.Allocator, n int32, field arrow.Field) *FixedSizeListBuilder { - return &FixedSizeListBuilder{ + fslb := &FixedSizeListBuilder{ baseListBuilder{ - builder: builder{refCount: 1, mem: mem}, + builder: builder{mem: mem}, values: NewBuilder(mem, field.Type), dt: arrow.FixedSizeListOfField(n, field), }, n, } + + fslb.baseListBuilder.builder.refCount.Add(1) + return fslb } func (b *FixedSizeListBuilder) Type() arrow.DataType { return b.dt } @@ -198,9 +203,9 @@ func (b *FixedSizeListBuilder) Type() arrow.DataType { return b.dt } // Release decreases the reference count by 1. // When the reference count goes to zero, the memory is freed. 
func (b *FixedSizeListBuilder) Release() { - debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases") + debug.Assert(b.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&b.refCount, -1) == 0 { + if b.refCount.Add(-1) == 0 { if b.nullBitmap != nil { b.nullBitmap.Release() b.nullBitmap = nil diff --git a/arrow/array/fixedsize_binary.go b/arrow/array/fixedsize_binary.go index 7049c9c0..a3b03806 100644 --- a/arrow/array/fixedsize_binary.go +++ b/arrow/array/fixedsize_binary.go @@ -37,7 +37,7 @@ type FixedSizeBinary struct { // NewFixedSizeBinaryData constructs a new fixed-size binary array from data. func NewFixedSizeBinaryData(data arrow.ArrayData) *FixedSizeBinary { a := &FixedSizeBinary{bytewidth: int32(data.DataType().(arrow.FixedWidthDataType).BitWidth() / 8)} - a.refCount = 1 + a.refCount.Add(1) a.setData(data.(*Data)) return a } @@ -52,6 +52,7 @@ func (a *FixedSizeBinary) Value(i int) []byte { ) return a.valueBytes[beg:end] } + func (a *FixedSizeBinary) ValueStr(i int) string { if a.IsNull(i) { return NullValueStr @@ -83,7 +84,6 @@ func (a *FixedSizeBinary) setData(data *Data) { if vals != nil { a.valueBytes = vals.Bytes() } - } func (a *FixedSizeBinary) GetOneForMarshal(i int) interface{} { @@ -118,6 +118,4 @@ func arrayEqualFixedSizeBinary(left, right *FixedSizeBinary) bool { return true } -var ( - _ arrow.Array = (*FixedSizeBinary)(nil) -) +var _ arrow.Array = (*FixedSizeBinary)(nil) diff --git a/arrow/array/fixedsize_binarybuilder.go b/arrow/array/fixedsize_binarybuilder.go index 02e72a25..ee7869fa 100644 --- a/arrow/array/fixedsize_binarybuilder.go +++ b/arrow/array/fixedsize_binarybuilder.go @@ -21,7 +21,6 @@ import ( "encoding/base64" "fmt" "reflect" - "sync/atomic" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/internal/debug" @@ -39,10 +38,11 @@ type FixedSizeBinaryBuilder struct { func NewFixedSizeBinaryBuilder(mem memory.Allocator, dtype *arrow.FixedSizeBinaryType) *FixedSizeBinaryBuilder { b := &FixedSizeBinaryBuilder{ - builder: builder{refCount: 1, mem: mem}, + builder: builder{mem: mem}, dtype: dtype, values: newByteBufferBuilder(mem), } + b.builder.refCount.Add(1) return b } @@ -52,9 +52,9 @@ func (b *FixedSizeBinaryBuilder) Type() arrow.DataType { return b.dtype } // When the reference count goes to zero, the memory is freed. // Release may be called simultaneously from multiple goroutines. 
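(Editorial sketch, not part of the patch.) The builder constructors touched in these hunks all follow one pattern: with refCount now an atomic.Int64, which should not be copied after first use, the literal initializer refCount: 1 is replaced by an explicit refCount.Add(1) once the struct has been allocated, and Release decrements the counter and frees owned memory when it reaches zero. A minimal sketch of the pattern, using hypothetical names rather than the package's actual types:

package example

import "sync/atomic"

type resource struct {
	refCount atomic.Int64 // zero value is ready to use; never copy after first use
	buf      []byte
}

func newResource() *resource {
	r := &resource{}
	r.refCount.Add(1) // the caller starts as the sole owner
	return r
}

func (r *resource) Retain() { r.refCount.Add(1) }

func (r *resource) Release() {
	if r.refCount.Add(-1) == 0 {
		r.buf = nil // drop owned memory once the last reference is gone
	}
}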
func (b *FixedSizeBinaryBuilder) Release() { - debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases") + debug.Assert(b.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&b.refCount, -1) == 0 { + if b.refCount.Add(-1) == 0 { if b.nullBitmap != nil { b.nullBitmap.Release() b.nullBitmap = nil @@ -256,6 +256,4 @@ func (b *FixedSizeBinaryBuilder) UnmarshalJSON(data []byte) error { return b.Unmarshal(dec) } -var ( - _ Builder = (*FixedSizeBinaryBuilder)(nil) -) +var _ Builder = (*FixedSizeBinaryBuilder)(nil) diff --git a/arrow/array/float16.go b/arrow/array/float16.go index 6b0e820f..5f57f725 100644 --- a/arrow/array/float16.go +++ b/arrow/array/float16.go @@ -33,7 +33,7 @@ type Float16 struct { func NewFloat16Data(data arrow.ArrayData) *Float16 { a := &Float16{} - a.refCount = 1 + a.refCount.Add(1) a.setData(data.(*Data)) return a } @@ -106,18 +106,7 @@ func (a *Float16) MarshalJSON() ([]byte, error) { return json.Marshal(vals) } -func arrayEqualFloat16(left, right *Float16) bool { - for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { - continue - } - if left.Value(i) != right.Value(i) { - return false - } - } - return true -} - var ( - _ arrow.Array = (*Float16)(nil) + _ arrow.Array = (*Float16)(nil) + _ arrow.TypedArray[float16.Num] = (*Float16)(nil) ) diff --git a/arrow/array/float16_builder.go b/arrow/array/float16_builder.go index 93dbfbc0..d4acd7f6 100644 --- a/arrow/array/float16_builder.go +++ b/arrow/array/float16_builder.go @@ -21,7 +21,6 @@ import ( "fmt" "reflect" "strconv" - "sync/atomic" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/bitutil" @@ -39,7 +38,9 @@ type Float16Builder struct { } func NewFloat16Builder(mem memory.Allocator) *Float16Builder { - return &Float16Builder{builder: builder{refCount: 1, mem: mem}} + fb := &Float16Builder{builder: builder{mem: mem}} + fb.refCount.Add(1) + return fb } func (b *Float16Builder) Type() arrow.DataType { return arrow.FixedWidthTypes.Float16 } @@ -47,9 +48,9 @@ func (b *Float16Builder) Type() arrow.DataType { return arrow.FixedWidthTypes.Fl // Release decreases the reference count by 1. // When the reference count goes to zero, the memory is freed. 
func (b *Float16Builder) Release() { - debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases") + debug.Assert(b.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&b.refCount, -1) == 0 { + if b.refCount.Add(-1) == 0 { if b.nullBitmap != nil { b.nullBitmap.Release() b.nullBitmap = nil diff --git a/arrow/array/interval.go b/arrow/array/interval.go index 324647e8..54915cdd 100644 --- a/arrow/array/interval.go +++ b/arrow/array/interval.go @@ -21,7 +21,6 @@ import ( "fmt" "strconv" "strings" - "sync/atomic" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/bitutil" @@ -51,7 +50,7 @@ type MonthInterval struct { func NewMonthIntervalData(data arrow.ArrayData) *MonthInterval { a := &MonthInterval{} - a.refCount = 1 + a.refCount.Add(1) a.setData(data.(*Data)) return a } @@ -63,7 +62,8 @@ func (a *MonthInterval) ValueStr(i int) string { } return fmt.Sprintf("%v", a.Value(i)) } -func (a *MonthInterval) MonthIntervalValues() []arrow.MonthInterval { return a.values } +func (a *MonthInterval) MonthIntervalValues() []arrow.MonthInterval { return a.Values() } +func (a *MonthInterval) Values() []arrow.MonthInterval { return a.values } func (a *MonthInterval) String() string { o := new(strings.Builder) @@ -140,7 +140,9 @@ type MonthIntervalBuilder struct { } func NewMonthIntervalBuilder(mem memory.Allocator) *MonthIntervalBuilder { - return &MonthIntervalBuilder{builder: builder{refCount: 1, mem: mem}} + mib := &MonthIntervalBuilder{builder: builder{mem: mem}} + mib.refCount.Add(1) + return mib } func (b *MonthIntervalBuilder) Type() arrow.DataType { return arrow.FixedWidthTypes.MonthInterval } @@ -148,9 +150,9 @@ func (b *MonthIntervalBuilder) Type() arrow.DataType { return arrow.FixedWidthTy // Release decreases the reference count by 1. // When the reference count goes to zero, the memory is freed. func (b *MonthIntervalBuilder) Release() { - debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases") + debug.Assert(b.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&b.refCount, -1) == 0 { + if b.refCount.Add(-1) == 0 { if b.nullBitmap != nil { b.nullBitmap.Release() b.nullBitmap = nil @@ -348,7 +350,7 @@ type DayTimeInterval struct { func NewDayTimeIntervalData(data arrow.ArrayData) *DayTimeInterval { a := &DayTimeInterval{} - a.refCount = 1 + a.refCount.Add(1) a.setData(data.(*Data)) return a } @@ -440,7 +442,9 @@ type DayTimeIntervalBuilder struct { } func NewDayTimeIntervalBuilder(mem memory.Allocator) *DayTimeIntervalBuilder { - return &DayTimeIntervalBuilder{builder: builder{refCount: 1, mem: mem}} + dtb := &DayTimeIntervalBuilder{builder: builder{mem: mem}} + dtb.refCount.Add(1) + return dtb } func (b *DayTimeIntervalBuilder) Type() arrow.DataType { return arrow.FixedWidthTypes.DayTimeInterval } @@ -448,9 +452,9 @@ func (b *DayTimeIntervalBuilder) Type() arrow.DataType { return arrow.FixedWidth // Release decreases the reference count by 1. // When the reference count goes to zero, the memory is freed. 
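(Editorial sketch, not part of the patch.) The compile-time assertions added at the end of this interval.go diff below, like the one added for Float16 earlier, declare that these arrays satisfy arrow.TypedArray with their element type, which lets callers write helpers against that interface instead of switching on concrete array types. A short usage sketch under that assumption; firstValidMonth is a hypothetical helper name.

package example

import "github.com/apache/arrow-go/v18/arrow"

// firstValidMonth returns the first non-null month interval in arr, if any.
func firstValidMonth(arr arrow.TypedArray[arrow.MonthInterval]) (arrow.MonthInterval, bool) {
	for i := 0; i < arr.Len(); i++ {
		if arr.IsValid(i) {
			return arr.Value(i), true
		}
	}
	return 0, false
}

The same shape works for the other asserted element types (for example float16.Num for *Float16), or as a single type-parameterized function over the interface's type parameter.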
func (b *DayTimeIntervalBuilder) Release() { - debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases") + debug.Assert(b.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&b.refCount, -1) == 0 { + if b.refCount.Add(-1) == 0 { if b.nullBitmap != nil { b.nullBitmap.Release() b.nullBitmap = nil @@ -647,7 +651,7 @@ type MonthDayNanoInterval struct { func NewMonthDayNanoIntervalData(data arrow.ArrayData) *MonthDayNanoInterval { a := &MonthDayNanoInterval{} - a.refCount = 1 + a.refCount.Add(1) a.setData(data.(*Data)) return a } @@ -741,7 +745,9 @@ type MonthDayNanoIntervalBuilder struct { } func NewMonthDayNanoIntervalBuilder(mem memory.Allocator) *MonthDayNanoIntervalBuilder { - return &MonthDayNanoIntervalBuilder{builder: builder{refCount: 1, mem: mem}} + mb := &MonthDayNanoIntervalBuilder{builder: builder{mem: mem}} + mb.refCount.Add(1) + return mb } func (b *MonthDayNanoIntervalBuilder) Type() arrow.DataType { @@ -751,9 +757,9 @@ func (b *MonthDayNanoIntervalBuilder) Type() arrow.DataType { // Release decreases the reference count by 1. // When the reference count goes to zero, the memory is freed. func (b *MonthDayNanoIntervalBuilder) Release() { - debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases") + debug.Assert(b.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&b.refCount, -1) == 0 { + if b.refCount.Add(-1) == 0 { if b.nullBitmap != nil { b.nullBitmap.Release() b.nullBitmap = nil @@ -950,4 +956,8 @@ var ( _ Builder = (*MonthIntervalBuilder)(nil) _ Builder = (*DayTimeIntervalBuilder)(nil) _ Builder = (*MonthDayNanoIntervalBuilder)(nil) + + _ arrow.TypedArray[arrow.MonthInterval] = (*MonthInterval)(nil) + _ arrow.TypedArray[arrow.DayTimeInterval] = (*DayTimeInterval)(nil) + _ arrow.TypedArray[arrow.MonthDayNanoInterval] = (*MonthDayNanoInterval)(nil) ) diff --git a/arrow/array/json_reader.go b/arrow/array/json_reader.go index 7835b280..b0698b3a 100644 --- a/arrow/array/json_reader.go +++ b/arrow/array/json_reader.go @@ -28,8 +28,10 @@ import ( "github.com/apache/arrow-go/v18/internal/json" ) -type Option func(config) -type config interface{} +type ( + Option func(config) + config interface{} +) // WithChunk sets the chunk size for reading in json records. The default is to // read in one row per record batch as a single object. 
If chunk size is set to @@ -72,7 +74,7 @@ type JSONReader struct { bldr *RecordBuilder - refs int64 + refs atomic.Int64 cur arrow.Record err error @@ -93,9 +95,10 @@ func NewJSONReader(r io.Reader, schema *arrow.Schema, opts ...Option) *JSONReade rr := &JSONReader{ r: json.NewDecoder(r), schema: schema, - refs: 1, chunk: 1, } + rr.refs.Add(1) + for _, o := range opts { o(rr) } @@ -126,13 +129,13 @@ func (r *JSONReader) Schema() *arrow.Schema { return r.schema } func (r *JSONReader) Record() arrow.Record { return r.cur } func (r *JSONReader) Retain() { - atomic.AddInt64(&r.refs, 1) + r.refs.Add(1) } func (r *JSONReader) Release() { - debug.Assert(atomic.LoadInt64(&r.refs) > 0, "too many releases") + debug.Assert(r.refs.Load() > 0, "too many releases") - if atomic.AddInt64(&r.refs, -1) == 0 { + if r.refs.Add(-1) == 0 { if r.cur != nil { r.cur.Release() r.bldr.Release() @@ -186,7 +189,7 @@ func (r *JSONReader) next1() bool { } func (r *JSONReader) nextn() bool { - var n = 0 + n := 0 for i := 0; i < r.chunk && !r.done; i, n = i+1, n+1 { if !r.readNext() { @@ -200,6 +203,4 @@ func (r *JSONReader) nextn() bool { return n > 0 } -var ( - _ RecordReader = (*JSONReader)(nil) -) +var _ RecordReader = (*JSONReader)(nil) diff --git a/arrow/array/list.go b/arrow/array/list.go index e80bc896..806b89c9 100644 --- a/arrow/array/list.go +++ b/arrow/array/list.go @@ -20,7 +20,6 @@ import ( "bytes" "fmt" "strings" - "sync/atomic" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/bitutil" @@ -51,7 +50,7 @@ var _ ListLike = (*List)(nil) // NewListData returns a new List array value, from data. func NewListData(data arrow.ArrayData) *List { a := &List{} - a.refCount = 1 + a.refCount.Add(1) a.setData(data.(*Data)) return a } @@ -183,7 +182,7 @@ var _ ListLike = (*LargeList)(nil) // NewLargeListData returns a new LargeList array value, from data. func NewLargeListData(data arrow.ArrayData) *LargeList { a := new(LargeList) - a.refCount = 1 + a.refCount.Add(1) a.setData(data.(*Data)) return a } @@ -337,30 +336,34 @@ type LargeListBuilder struct { // The created list builder will create a list whose elements will be of type etype. func NewListBuilder(mem memory.Allocator, etype arrow.DataType) *ListBuilder { offsetBldr := NewInt32Builder(mem) - return &ListBuilder{ + lb := &ListBuilder{ baseListBuilder{ - builder: builder{refCount: 1, mem: mem}, + builder: builder{mem: mem}, values: NewBuilder(mem, etype), offsets: offsetBldr, dt: arrow.ListOf(etype), appendOffsetVal: func(o int) { offsetBldr.Append(int32(o)) }, }, } + lb.refCount.Add(1) + return lb } // NewListBuilderWithField takes a field to use for the child rather than just // a datatype to allow for more customization. func NewListBuilderWithField(mem memory.Allocator, field arrow.Field) *ListBuilder { offsetBldr := NewInt32Builder(mem) - return &ListBuilder{ + lb := &ListBuilder{ baseListBuilder{ - builder: builder{refCount: 1, mem: mem}, + builder: builder{mem: mem}, values: NewBuilder(mem, field.Type), offsets: offsetBldr, dt: arrow.ListOfField(field), appendOffsetVal: func(o int) { offsetBldr.Append(int32(o)) }, }, } + lb.refCount.Add(1) + return lb } func (b *baseListBuilder) Type() arrow.DataType { @@ -381,38 +384,42 @@ func (b *baseListBuilder) Type() arrow.DataType { // The created list builder will create a list whose elements will be of type etype. 
func NewLargeListBuilder(mem memory.Allocator, etype arrow.DataType) *LargeListBuilder { offsetBldr := NewInt64Builder(mem) - return &LargeListBuilder{ + llb := &LargeListBuilder{ baseListBuilder{ - builder: builder{refCount: 1, mem: mem}, + builder: builder{mem: mem}, values: NewBuilder(mem, etype), offsets: offsetBldr, dt: arrow.LargeListOf(etype), appendOffsetVal: func(o int) { offsetBldr.Append(int64(o)) }, }, } + llb.refCount.Add(1) + return llb } // NewLargeListBuilderWithField takes a field rather than just an element type // to allow for more customization of the final type of the LargeList Array func NewLargeListBuilderWithField(mem memory.Allocator, field arrow.Field) *LargeListBuilder { offsetBldr := NewInt64Builder(mem) - return &LargeListBuilder{ + llb := &LargeListBuilder{ baseListBuilder{ - builder: builder{refCount: 1, mem: mem}, + builder: builder{mem: mem}, values: NewBuilder(mem, field.Type), offsets: offsetBldr, dt: arrow.LargeListOfField(field), appendOffsetVal: func(o int) { offsetBldr.Append(int64(o)) }, }, } + llb.refCount.Add(1) + return llb } // Release decreases the reference count by 1. // When the reference count goes to zero, the memory is freed. func (b *baseListBuilder) Release() { - debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases") + debug.Assert(b.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&b.refCount, -1) == 0 { + if b.refCount.Add(-1) == 0 { if b.nullBitmap != nil { b.nullBitmap.Release() b.nullBitmap = nil @@ -420,7 +427,6 @@ func (b *baseListBuilder) Release() { b.values.Release() b.offsets.Release() } - } func (b *baseListBuilder) appendNextOffset() { @@ -646,7 +652,7 @@ var _ VarLenListLike = (*ListView)(nil) func NewListViewData(data arrow.ArrayData) *ListView { a := &ListView{} - a.refCount = 1 + a.refCount.Add(1) a.setData(data.(*Data)) return a } @@ -793,7 +799,7 @@ var _ VarLenListLike = (*LargeListView)(nil) // NewLargeListViewData returns a new LargeListView array value, from data. 
func NewLargeListViewData(data arrow.ArrayData) *LargeListView { a := new(LargeListView) - a.refCount = 1 + a.refCount.Add(1) a.setData(data.(*Data)) return a } @@ -931,8 +937,10 @@ type offsetsAndSizes interface { sizeAt(slot int64) int64 } -var _ offsetsAndSizes = (*ListView)(nil) -var _ offsetsAndSizes = (*LargeListView)(nil) +var ( + _ offsetsAndSizes = (*ListView)(nil) + _ offsetsAndSizes = (*LargeListView)(nil) +) func (a *ListView) offsetAt(slot int64) int64 { return int64(a.offsets[int64(a.data.offset)+slot]) } @@ -1081,9 +1089,9 @@ type LargeListViewBuilder struct { func NewListViewBuilder(mem memory.Allocator, etype arrow.DataType) *ListViewBuilder { offsetBldr := NewInt32Builder(mem) sizeBldr := NewInt32Builder(mem) - return &ListViewBuilder{ + lvb := &ListViewBuilder{ baseListViewBuilder{ - builder: builder{refCount: 1, mem: mem}, + builder: builder{mem: mem}, values: NewBuilder(mem, etype), offsets: offsetBldr, sizes: sizeBldr, @@ -1092,6 +1100,8 @@ func NewListViewBuilder(mem memory.Allocator, etype arrow.DataType) *ListViewBui appendSizeVal: func(s int) { sizeBldr.Append(int32(s)) }, }, } + lvb.refCount.Add(1) + return lvb } // NewListViewBuilderWithField takes a field to use for the child rather than just @@ -1099,9 +1109,9 @@ func NewListViewBuilder(mem memory.Allocator, etype arrow.DataType) *ListViewBui func NewListViewBuilderWithField(mem memory.Allocator, field arrow.Field) *ListViewBuilder { offsetBldr := NewInt32Builder(mem) sizeBldr := NewInt32Builder(mem) - return &ListViewBuilder{ + lvb := &ListViewBuilder{ baseListViewBuilder{ - builder: builder{refCount: 1, mem: mem}, + builder: builder{mem: mem}, values: NewBuilder(mem, field.Type), offsets: offsetBldr, sizes: sizeBldr, @@ -1110,6 +1120,8 @@ func NewListViewBuilderWithField(mem memory.Allocator, field arrow.Field) *ListV appendSizeVal: func(s int) { sizeBldr.Append(int32(s)) }, }, } + lvb.refCount.Add(1) + return lvb } func (b *baseListViewBuilder) Type() arrow.DataType { @@ -1131,9 +1143,9 @@ func (b *baseListViewBuilder) Type() arrow.DataType { func NewLargeListViewBuilder(mem memory.Allocator, etype arrow.DataType) *LargeListViewBuilder { offsetBldr := NewInt64Builder(mem) sizeBldr := NewInt64Builder(mem) - return &LargeListViewBuilder{ + llvb := &LargeListViewBuilder{ baseListViewBuilder{ - builder: builder{refCount: 1, mem: mem}, + builder: builder{mem: mem}, values: NewBuilder(mem, etype), offsets: offsetBldr, sizes: sizeBldr, @@ -1142,6 +1154,8 @@ func NewLargeListViewBuilder(mem memory.Allocator, etype arrow.DataType) *LargeL appendSizeVal: func(s int) { sizeBldr.Append(int64(s)) }, }, } + llvb.refCount.Add(1) + return llvb } // NewLargeListViewBuilderWithField takes a field rather than just an element type @@ -1149,9 +1163,9 @@ func NewLargeListViewBuilder(mem memory.Allocator, etype arrow.DataType) *LargeL func NewLargeListViewBuilderWithField(mem memory.Allocator, field arrow.Field) *LargeListViewBuilder { offsetBldr := NewInt64Builder(mem) sizeBldr := NewInt64Builder(mem) - return &LargeListViewBuilder{ + llvb := &LargeListViewBuilder{ baseListViewBuilder{ - builder: builder{refCount: 1, mem: mem}, + builder: builder{mem: mem}, values: NewBuilder(mem, field.Type), offsets: offsetBldr, sizes: sizeBldr, @@ -1160,14 +1174,17 @@ func NewLargeListViewBuilderWithField(mem memory.Allocator, field arrow.Field) * appendSizeVal: func(o int) { sizeBldr.Append(int64(o)) }, }, } + + llvb.refCount.Add(1) + return llvb } // Release decreases the reference count by 1. 
// When the reference count goes to zero, the memory is freed. func (b *baseListViewBuilder) Release() { - debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases") + debug.Assert(b.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&b.refCount, -1) == 0 { + if b.refCount.Add(-1) == 0 { if b.nullBitmap != nil { b.nullBitmap.Release() b.nullBitmap = nil diff --git a/arrow/array/map.go b/arrow/array/map.go index 5609ccd0..da9a150b 100644 --- a/arrow/array/map.go +++ b/arrow/array/map.go @@ -37,7 +37,7 @@ var _ ListLike = (*Map)(nil) // NewMapData returns a new Map array value, from data func NewMapData(data arrow.ArrayData) *Map { a := &Map{List: &List{}} - a.refCount = 1 + a.refCount.Add(1) a.setData(data.(*Data)) return a } diff --git a/arrow/array/null.go b/arrow/array/null.go index 76e56a49..38b3b097 100644 --- a/arrow/array/null.go +++ b/arrow/array/null.go @@ -21,7 +21,6 @@ import ( "fmt" "reflect" "strings" - "sync/atomic" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/internal/debug" @@ -37,7 +36,7 @@ type Null struct { // NewNull returns a new Null array value of size n. func NewNull(n int) *Null { a := &Null{} - a.refCount = 1 + a.refCount.Add(1) data := NewData( arrow.Null, n, []*memory.Buffer{nil}, @@ -53,7 +52,7 @@ func NewNull(n int) *Null { // NewNullData returns a new Null array value, from data. func NewNullData(data arrow.ArrayData) *Null { a := &Null{} - a.refCount = 1 + a.refCount.Add(1) a.setData(data.(*Data)) return a } @@ -95,7 +94,9 @@ type NullBuilder struct { // NewNullBuilder returns a builder, using the provided memory allocator. func NewNullBuilder(mem memory.Allocator) *NullBuilder { - return &NullBuilder{builder: builder{refCount: 1, mem: mem}} + nb := &NullBuilder{builder: builder{mem: mem}} + nb.refCount.Add(1) + return nb } func (b *NullBuilder) Type() arrow.DataType { return arrow.Null } @@ -103,9 +104,9 @@ func (b *NullBuilder) Type() arrow.DataType { return arrow.Null } // Release decreases the reference count by 1. // When the reference count goes to zero, the memory is freed. func (b *NullBuilder) Release() { - debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases") + debug.Assert(b.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&b.refCount, -1) == 0 { + if b.refCount.Add(-1) == 0 { if b.nullBitmap != nil { b.nullBitmap.Release() b.nullBitmap = nil diff --git a/arrow/array/numeric.gen.go b/arrow/array/numeric.gen.go deleted file mode 100644 index 7e94fe5c..00000000 --- a/arrow/array/numeric.gen.go +++ /dev/null @@ -1,1469 +0,0 @@ -// Code generated by array/numeric.gen.go.tmpl. DO NOT EDIT. - -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package array - -import ( - "fmt" - "math" - "strconv" - "strings" - - "github.com/apache/arrow-go/v18/arrow" - "github.com/apache/arrow-go/v18/internal/json" -) - -// A type which represents an immutable sequence of int64 values. -type Int64 struct { - array - values []int64 -} - -// NewInt64Data creates a new Int64. -func NewInt64Data(data arrow.ArrayData) *Int64 { - a := &Int64{} - a.refCount = 1 - a.setData(data.(*Data)) - return a -} - -// Reset resets the array for re-use. -func (a *Int64) Reset(data *Data) { - a.setData(data) -} - -// Value returns the value at the specified index. -func (a *Int64) Value(i int) int64 { return a.values[i] } - -// Values returns the values. -func (a *Int64) Int64Values() []int64 { return a.values } - -// String returns a string representation of the array. -func (a *Int64) String() string { - o := new(strings.Builder) - o.WriteString("[") - for i, v := range a.values { - if i > 0 { - fmt.Fprintf(o, " ") - } - switch { - case a.IsNull(i): - o.WriteString(NullValueStr) - default: - fmt.Fprintf(o, "%v", v) - } - } - o.WriteString("]") - return o.String() -} - -func (a *Int64) setData(data *Data) { - a.array.setData(data) - vals := data.buffers[1] - if vals != nil { - a.values = arrow.Int64Traits.CastFromBytes(vals.Bytes()) - beg := a.array.data.offset - end := beg + a.array.data.length - a.values = a.values[beg:end] - } -} - -func (a *Int64) ValueStr(i int) string { - if a.IsNull(i) { - return NullValueStr - } - return strconv.FormatInt(int64(a.Value(i)), 10) -} - -func (a *Int64) GetOneForMarshal(i int) interface{} { - if a.IsNull(i) { - return nil - } - - return a.values[i] -} - -func (a *Int64) MarshalJSON() ([]byte, error) { - vals := make([]interface{}, a.Len()) - for i := 0; i < a.Len(); i++ { - - if a.IsValid(i) { - vals[i] = a.values[i] - } else { - vals[i] = nil - } - - } - - return json.Marshal(vals) -} - -func arrayEqualInt64(left, right *Int64) bool { - for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { - continue - } - if left.Value(i) != right.Value(i) { - return false - } - } - return true -} - -// A type which represents an immutable sequence of uint64 values. -type Uint64 struct { - array - values []uint64 -} - -// NewUint64Data creates a new Uint64. -func NewUint64Data(data arrow.ArrayData) *Uint64 { - a := &Uint64{} - a.refCount = 1 - a.setData(data.(*Data)) - return a -} - -// Reset resets the array for re-use. -func (a *Uint64) Reset(data *Data) { - a.setData(data) -} - -// Value returns the value at the specified index. -func (a *Uint64) Value(i int) uint64 { return a.values[i] } - -// Values returns the values. -func (a *Uint64) Uint64Values() []uint64 { return a.values } - -// String returns a string representation of the array. 
-func (a *Uint64) String() string { - o := new(strings.Builder) - o.WriteString("[") - for i, v := range a.values { - if i > 0 { - fmt.Fprintf(o, " ") - } - switch { - case a.IsNull(i): - o.WriteString(NullValueStr) - default: - fmt.Fprintf(o, "%v", v) - } - } - o.WriteString("]") - return o.String() -} - -func (a *Uint64) setData(data *Data) { - a.array.setData(data) - vals := data.buffers[1] - if vals != nil { - a.values = arrow.Uint64Traits.CastFromBytes(vals.Bytes()) - beg := a.array.data.offset - end := beg + a.array.data.length - a.values = a.values[beg:end] - } -} - -func (a *Uint64) ValueStr(i int) string { - if a.IsNull(i) { - return NullValueStr - } - return strconv.FormatUint(uint64(a.Value(i)), 10) -} - -func (a *Uint64) GetOneForMarshal(i int) interface{} { - if a.IsNull(i) { - return nil - } - - return a.values[i] -} - -func (a *Uint64) MarshalJSON() ([]byte, error) { - vals := make([]interface{}, a.Len()) - for i := 0; i < a.Len(); i++ { - - if a.IsValid(i) { - vals[i] = a.values[i] - } else { - vals[i] = nil - } - - } - - return json.Marshal(vals) -} - -func arrayEqualUint64(left, right *Uint64) bool { - for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { - continue - } - if left.Value(i) != right.Value(i) { - return false - } - } - return true -} - -// A type which represents an immutable sequence of float64 values. -type Float64 struct { - array - values []float64 -} - -// NewFloat64Data creates a new Float64. -func NewFloat64Data(data arrow.ArrayData) *Float64 { - a := &Float64{} - a.refCount = 1 - a.setData(data.(*Data)) - return a -} - -// Reset resets the array for re-use. -func (a *Float64) Reset(data *Data) { - a.setData(data) -} - -// Value returns the value at the specified index. -func (a *Float64) Value(i int) float64 { return a.values[i] } - -// Values returns the values. -func (a *Float64) Float64Values() []float64 { return a.values } - -// String returns a string representation of the array. -func (a *Float64) String() string { - o := new(strings.Builder) - o.WriteString("[") - for i, v := range a.values { - if i > 0 { - fmt.Fprintf(o, " ") - } - switch { - case a.IsNull(i): - o.WriteString(NullValueStr) - default: - fmt.Fprintf(o, "%v", v) - } - } - o.WriteString("]") - return o.String() -} - -func (a *Float64) setData(data *Data) { - a.array.setData(data) - vals := data.buffers[1] - if vals != nil { - a.values = arrow.Float64Traits.CastFromBytes(vals.Bytes()) - beg := a.array.data.offset - end := beg + a.array.data.length - a.values = a.values[beg:end] - } -} - -func (a *Float64) ValueStr(i int) string { - if a.IsNull(i) { - return NullValueStr - } - return strconv.FormatFloat(float64(a.Value(i)), 'g', -1, 64) -} - -func (a *Float64) GetOneForMarshal(i int) interface{} { - if a.IsNull(i) { - return nil - } - - return a.values[i] -} - -func (a *Float64) MarshalJSON() ([]byte, error) { - vals := make([]interface{}, a.Len()) - for i := 0; i < a.Len(); i++ { - if !a.IsValid(i) { - vals[i] = nil - continue - } - - f := a.Value(i) - switch { - case math.IsNaN(f): - vals[i] = "NaN" - case math.IsInf(f, 1): - vals[i] = "+Inf" - case math.IsInf(f, -1): - vals[i] = "-Inf" - default: - vals[i] = f - } - - } - - return json.Marshal(vals) -} - -func arrayEqualFloat64(left, right *Float64) bool { - for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { - continue - } - if left.Value(i) != right.Value(i) { - return false - } - } - return true -} - -// A type which represents an immutable sequence of int32 values. 
-type Int32 struct { - array - values []int32 -} - -// NewInt32Data creates a new Int32. -func NewInt32Data(data arrow.ArrayData) *Int32 { - a := &Int32{} - a.refCount = 1 - a.setData(data.(*Data)) - return a -} - -// Reset resets the array for re-use. -func (a *Int32) Reset(data *Data) { - a.setData(data) -} - -// Value returns the value at the specified index. -func (a *Int32) Value(i int) int32 { return a.values[i] } - -// Values returns the values. -func (a *Int32) Int32Values() []int32 { return a.values } - -// String returns a string representation of the array. -func (a *Int32) String() string { - o := new(strings.Builder) - o.WriteString("[") - for i, v := range a.values { - if i > 0 { - fmt.Fprintf(o, " ") - } - switch { - case a.IsNull(i): - o.WriteString(NullValueStr) - default: - fmt.Fprintf(o, "%v", v) - } - } - o.WriteString("]") - return o.String() -} - -func (a *Int32) setData(data *Data) { - a.array.setData(data) - vals := data.buffers[1] - if vals != nil { - a.values = arrow.Int32Traits.CastFromBytes(vals.Bytes()) - beg := a.array.data.offset - end := beg + a.array.data.length - a.values = a.values[beg:end] - } -} - -func (a *Int32) ValueStr(i int) string { - if a.IsNull(i) { - return NullValueStr - } - return strconv.FormatInt(int64(a.Value(i)), 10) -} - -func (a *Int32) GetOneForMarshal(i int) interface{} { - if a.IsNull(i) { - return nil - } - - return a.values[i] -} - -func (a *Int32) MarshalJSON() ([]byte, error) { - vals := make([]interface{}, a.Len()) - for i := 0; i < a.Len(); i++ { - - if a.IsValid(i) { - vals[i] = a.values[i] - } else { - vals[i] = nil - } - - } - - return json.Marshal(vals) -} - -func arrayEqualInt32(left, right *Int32) bool { - for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { - continue - } - if left.Value(i) != right.Value(i) { - return false - } - } - return true -} - -// A type which represents an immutable sequence of uint32 values. -type Uint32 struct { - array - values []uint32 -} - -// NewUint32Data creates a new Uint32. -func NewUint32Data(data arrow.ArrayData) *Uint32 { - a := &Uint32{} - a.refCount = 1 - a.setData(data.(*Data)) - return a -} - -// Reset resets the array for re-use. -func (a *Uint32) Reset(data *Data) { - a.setData(data) -} - -// Value returns the value at the specified index. -func (a *Uint32) Value(i int) uint32 { return a.values[i] } - -// Values returns the values. -func (a *Uint32) Uint32Values() []uint32 { return a.values } - -// String returns a string representation of the array. 
-func (a *Uint32) String() string { - o := new(strings.Builder) - o.WriteString("[") - for i, v := range a.values { - if i > 0 { - fmt.Fprintf(o, " ") - } - switch { - case a.IsNull(i): - o.WriteString(NullValueStr) - default: - fmt.Fprintf(o, "%v", v) - } - } - o.WriteString("]") - return o.String() -} - -func (a *Uint32) setData(data *Data) { - a.array.setData(data) - vals := data.buffers[1] - if vals != nil { - a.values = arrow.Uint32Traits.CastFromBytes(vals.Bytes()) - beg := a.array.data.offset - end := beg + a.array.data.length - a.values = a.values[beg:end] - } -} - -func (a *Uint32) ValueStr(i int) string { - if a.IsNull(i) { - return NullValueStr - } - return strconv.FormatUint(uint64(a.Value(i)), 10) -} - -func (a *Uint32) GetOneForMarshal(i int) interface{} { - if a.IsNull(i) { - return nil - } - - return a.values[i] -} - -func (a *Uint32) MarshalJSON() ([]byte, error) { - vals := make([]interface{}, a.Len()) - for i := 0; i < a.Len(); i++ { - - if a.IsValid(i) { - vals[i] = a.values[i] - } else { - vals[i] = nil - } - - } - - return json.Marshal(vals) -} - -func arrayEqualUint32(left, right *Uint32) bool { - for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { - continue - } - if left.Value(i) != right.Value(i) { - return false - } - } - return true -} - -// A type which represents an immutable sequence of float32 values. -type Float32 struct { - array - values []float32 -} - -// NewFloat32Data creates a new Float32. -func NewFloat32Data(data arrow.ArrayData) *Float32 { - a := &Float32{} - a.refCount = 1 - a.setData(data.(*Data)) - return a -} - -// Reset resets the array for re-use. -func (a *Float32) Reset(data *Data) { - a.setData(data) -} - -// Value returns the value at the specified index. -func (a *Float32) Value(i int) float32 { return a.values[i] } - -// Values returns the values. -func (a *Float32) Float32Values() []float32 { return a.values } - -// String returns a string representation of the array. -func (a *Float32) String() string { - o := new(strings.Builder) - o.WriteString("[") - for i, v := range a.values { - if i > 0 { - fmt.Fprintf(o, " ") - } - switch { - case a.IsNull(i): - o.WriteString(NullValueStr) - default: - fmt.Fprintf(o, "%v", v) - } - } - o.WriteString("]") - return o.String() -} - -func (a *Float32) setData(data *Data) { - a.array.setData(data) - vals := data.buffers[1] - if vals != nil { - a.values = arrow.Float32Traits.CastFromBytes(vals.Bytes()) - beg := a.array.data.offset - end := beg + a.array.data.length - a.values = a.values[beg:end] - } -} - -func (a *Float32) ValueStr(i int) string { - if a.IsNull(i) { - return NullValueStr - } - return strconv.FormatFloat(float64(a.Value(i)), 'g', -1, 32) -} - -func (a *Float32) GetOneForMarshal(i int) interface{} { - if a.IsNull(i) { - return nil - } - - return a.values[i] -} - -func (a *Float32) MarshalJSON() ([]byte, error) { - vals := make([]interface{}, a.Len()) - for i := 0; i < a.Len(); i++ { - if !a.IsValid(i) { - vals[i] = nil - continue - } - - f := a.Value(i) - v := strconv.FormatFloat(float64(f), 'g', -1, 32) - - switch v { - case "NaN", "+Inf", "-Inf": - vals[i] = v - default: - vals[i] = f - } - - } - - return json.Marshal(vals) -} - -func arrayEqualFloat32(left, right *Float32) bool { - for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { - continue - } - if left.Value(i) != right.Value(i) { - return false - } - } - return true -} - -// A type which represents an immutable sequence of int16 values. 
-type Int16 struct { - array - values []int16 -} - -// NewInt16Data creates a new Int16. -func NewInt16Data(data arrow.ArrayData) *Int16 { - a := &Int16{} - a.refCount = 1 - a.setData(data.(*Data)) - return a -} - -// Reset resets the array for re-use. -func (a *Int16) Reset(data *Data) { - a.setData(data) -} - -// Value returns the value at the specified index. -func (a *Int16) Value(i int) int16 { return a.values[i] } - -// Values returns the values. -func (a *Int16) Int16Values() []int16 { return a.values } - -// String returns a string representation of the array. -func (a *Int16) String() string { - o := new(strings.Builder) - o.WriteString("[") - for i, v := range a.values { - if i > 0 { - fmt.Fprintf(o, " ") - } - switch { - case a.IsNull(i): - o.WriteString(NullValueStr) - default: - fmt.Fprintf(o, "%v", v) - } - } - o.WriteString("]") - return o.String() -} - -func (a *Int16) setData(data *Data) { - a.array.setData(data) - vals := data.buffers[1] - if vals != nil { - a.values = arrow.Int16Traits.CastFromBytes(vals.Bytes()) - beg := a.array.data.offset - end := beg + a.array.data.length - a.values = a.values[beg:end] - } -} - -func (a *Int16) ValueStr(i int) string { - if a.IsNull(i) { - return NullValueStr - } - return strconv.FormatInt(int64(a.Value(i)), 10) -} - -func (a *Int16) GetOneForMarshal(i int) interface{} { - if a.IsNull(i) { - return nil - } - - return a.values[i] -} - -func (a *Int16) MarshalJSON() ([]byte, error) { - vals := make([]interface{}, a.Len()) - for i := 0; i < a.Len(); i++ { - - if a.IsValid(i) { - vals[i] = a.values[i] - } else { - vals[i] = nil - } - - } - - return json.Marshal(vals) -} - -func arrayEqualInt16(left, right *Int16) bool { - for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { - continue - } - if left.Value(i) != right.Value(i) { - return false - } - } - return true -} - -// A type which represents an immutable sequence of uint16 values. -type Uint16 struct { - array - values []uint16 -} - -// NewUint16Data creates a new Uint16. -func NewUint16Data(data arrow.ArrayData) *Uint16 { - a := &Uint16{} - a.refCount = 1 - a.setData(data.(*Data)) - return a -} - -// Reset resets the array for re-use. -func (a *Uint16) Reset(data *Data) { - a.setData(data) -} - -// Value returns the value at the specified index. -func (a *Uint16) Value(i int) uint16 { return a.values[i] } - -// Values returns the values. -func (a *Uint16) Uint16Values() []uint16 { return a.values } - -// String returns a string representation of the array. 
-func (a *Uint16) String() string { - o := new(strings.Builder) - o.WriteString("[") - for i, v := range a.values { - if i > 0 { - fmt.Fprintf(o, " ") - } - switch { - case a.IsNull(i): - o.WriteString(NullValueStr) - default: - fmt.Fprintf(o, "%v", v) - } - } - o.WriteString("]") - return o.String() -} - -func (a *Uint16) setData(data *Data) { - a.array.setData(data) - vals := data.buffers[1] - if vals != nil { - a.values = arrow.Uint16Traits.CastFromBytes(vals.Bytes()) - beg := a.array.data.offset - end := beg + a.array.data.length - a.values = a.values[beg:end] - } -} - -func (a *Uint16) ValueStr(i int) string { - if a.IsNull(i) { - return NullValueStr - } - return strconv.FormatUint(uint64(a.Value(i)), 10) -} - -func (a *Uint16) GetOneForMarshal(i int) interface{} { - if a.IsNull(i) { - return nil - } - - return a.values[i] -} - -func (a *Uint16) MarshalJSON() ([]byte, error) { - vals := make([]interface{}, a.Len()) - for i := 0; i < a.Len(); i++ { - - if a.IsValid(i) { - vals[i] = a.values[i] - } else { - vals[i] = nil - } - - } - - return json.Marshal(vals) -} - -func arrayEqualUint16(left, right *Uint16) bool { - for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { - continue - } - if left.Value(i) != right.Value(i) { - return false - } - } - return true -} - -// A type which represents an immutable sequence of int8 values. -type Int8 struct { - array - values []int8 -} - -// NewInt8Data creates a new Int8. -func NewInt8Data(data arrow.ArrayData) *Int8 { - a := &Int8{} - a.refCount = 1 - a.setData(data.(*Data)) - return a -} - -// Reset resets the array for re-use. -func (a *Int8) Reset(data *Data) { - a.setData(data) -} - -// Value returns the value at the specified index. -func (a *Int8) Value(i int) int8 { return a.values[i] } - -// Values returns the values. -func (a *Int8) Int8Values() []int8 { return a.values } - -// String returns a string representation of the array. -func (a *Int8) String() string { - o := new(strings.Builder) - o.WriteString("[") - for i, v := range a.values { - if i > 0 { - fmt.Fprintf(o, " ") - } - switch { - case a.IsNull(i): - o.WriteString(NullValueStr) - default: - fmt.Fprintf(o, "%v", v) - } - } - o.WriteString("]") - return o.String() -} - -func (a *Int8) setData(data *Data) { - a.array.setData(data) - vals := data.buffers[1] - if vals != nil { - a.values = arrow.Int8Traits.CastFromBytes(vals.Bytes()) - beg := a.array.data.offset - end := beg + a.array.data.length - a.values = a.values[beg:end] - } -} - -func (a *Int8) ValueStr(i int) string { - if a.IsNull(i) { - return NullValueStr - } - return strconv.FormatInt(int64(a.Value(i)), 10) -} - -func (a *Int8) GetOneForMarshal(i int) interface{} { - if a.IsNull(i) { - return nil - } - - return float64(a.values[i]) // prevent uint8 from being seen as binary data -} - -func (a *Int8) MarshalJSON() ([]byte, error) { - vals := make([]interface{}, a.Len()) - for i := 0; i < a.Len(); i++ { - - if a.IsValid(i) { - vals[i] = float64(a.values[i]) // prevent uint8 from being seen as binary data - } else { - vals[i] = nil - } - - } - - return json.Marshal(vals) -} - -func arrayEqualInt8(left, right *Int8) bool { - for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { - continue - } - if left.Value(i) != right.Value(i) { - return false - } - } - return true -} - -// A type which represents an immutable sequence of uint8 values. -type Uint8 struct { - array - values []uint8 -} - -// NewUint8Data creates a new Uint8. 
-func NewUint8Data(data arrow.ArrayData) *Uint8 { - a := &Uint8{} - a.refCount = 1 - a.setData(data.(*Data)) - return a -} - -// Reset resets the array for re-use. -func (a *Uint8) Reset(data *Data) { - a.setData(data) -} - -// Value returns the value at the specified index. -func (a *Uint8) Value(i int) uint8 { return a.values[i] } - -// Values returns the values. -func (a *Uint8) Uint8Values() []uint8 { return a.values } - -// String returns a string representation of the array. -func (a *Uint8) String() string { - o := new(strings.Builder) - o.WriteString("[") - for i, v := range a.values { - if i > 0 { - fmt.Fprintf(o, " ") - } - switch { - case a.IsNull(i): - o.WriteString(NullValueStr) - default: - fmt.Fprintf(o, "%v", v) - } - } - o.WriteString("]") - return o.String() -} - -func (a *Uint8) setData(data *Data) { - a.array.setData(data) - vals := data.buffers[1] - if vals != nil { - a.values = arrow.Uint8Traits.CastFromBytes(vals.Bytes()) - beg := a.array.data.offset - end := beg + a.array.data.length - a.values = a.values[beg:end] - } -} - -func (a *Uint8) ValueStr(i int) string { - if a.IsNull(i) { - return NullValueStr - } - return strconv.FormatUint(uint64(a.Value(i)), 10) -} - -func (a *Uint8) GetOneForMarshal(i int) interface{} { - if a.IsNull(i) { - return nil - } - - return float64(a.values[i]) // prevent uint8 from being seen as binary data -} - -func (a *Uint8) MarshalJSON() ([]byte, error) { - vals := make([]interface{}, a.Len()) - for i := 0; i < a.Len(); i++ { - - if a.IsValid(i) { - vals[i] = float64(a.values[i]) // prevent uint8 from being seen as binary data - } else { - vals[i] = nil - } - - } - - return json.Marshal(vals) -} - -func arrayEqualUint8(left, right *Uint8) bool { - for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { - continue - } - if left.Value(i) != right.Value(i) { - return false - } - } - return true -} - -// A type which represents an immutable sequence of arrow.Time32 values. -type Time32 struct { - array - values []arrow.Time32 -} - -// NewTime32Data creates a new Time32. -func NewTime32Data(data arrow.ArrayData) *Time32 { - a := &Time32{} - a.refCount = 1 - a.setData(data.(*Data)) - return a -} - -// Reset resets the array for re-use. -func (a *Time32) Reset(data *Data) { - a.setData(data) -} - -// Value returns the value at the specified index. -func (a *Time32) Value(i int) arrow.Time32 { return a.values[i] } - -// Values returns the values. -func (a *Time32) Time32Values() []arrow.Time32 { return a.values } - -// String returns a string representation of the array. 
-func (a *Time32) String() string { - o := new(strings.Builder) - o.WriteString("[") - for i, v := range a.values { - if i > 0 { - fmt.Fprintf(o, " ") - } - switch { - case a.IsNull(i): - o.WriteString(NullValueStr) - default: - fmt.Fprintf(o, "%v", v) - } - } - o.WriteString("]") - return o.String() -} - -func (a *Time32) setData(data *Data) { - a.array.setData(data) - vals := data.buffers[1] - if vals != nil { - a.values = arrow.Time32Traits.CastFromBytes(vals.Bytes()) - beg := a.array.data.offset - end := beg + a.array.data.length - a.values = a.values[beg:end] - } -} - -func (a *Time32) ValueStr(i int) string { - if a.IsNull(i) { - return NullValueStr - } - return a.values[i].FormattedString(a.DataType().(*arrow.Time32Type).Unit) -} - -func (a *Time32) GetOneForMarshal(i int) interface{} { - if a.IsNull(i) { - return nil - } - return a.values[i].ToTime(a.DataType().(*arrow.Time32Type).Unit).Format("15:04:05.999999999") -} - -func (a *Time32) MarshalJSON() ([]byte, error) { - vals := make([]interface{}, a.Len()) - for i := range a.values { - vals[i] = a.GetOneForMarshal(i) - } - - return json.Marshal(vals) -} - -func arrayEqualTime32(left, right *Time32) bool { - for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { - continue - } - if left.Value(i) != right.Value(i) { - return false - } - } - return true -} - -// A type which represents an immutable sequence of arrow.Time64 values. -type Time64 struct { - array - values []arrow.Time64 -} - -// NewTime64Data creates a new Time64. -func NewTime64Data(data arrow.ArrayData) *Time64 { - a := &Time64{} - a.refCount = 1 - a.setData(data.(*Data)) - return a -} - -// Reset resets the array for re-use. -func (a *Time64) Reset(data *Data) { - a.setData(data) -} - -// Value returns the value at the specified index. -func (a *Time64) Value(i int) arrow.Time64 { return a.values[i] } - -// Values returns the values. -func (a *Time64) Time64Values() []arrow.Time64 { return a.values } - -// String returns a string representation of the array. -func (a *Time64) String() string { - o := new(strings.Builder) - o.WriteString("[") - for i, v := range a.values { - if i > 0 { - fmt.Fprintf(o, " ") - } - switch { - case a.IsNull(i): - o.WriteString(NullValueStr) - default: - fmt.Fprintf(o, "%v", v) - } - } - o.WriteString("]") - return o.String() -} - -func (a *Time64) setData(data *Data) { - a.array.setData(data) - vals := data.buffers[1] - if vals != nil { - a.values = arrow.Time64Traits.CastFromBytes(vals.Bytes()) - beg := a.array.data.offset - end := beg + a.array.data.length - a.values = a.values[beg:end] - } -} - -func (a *Time64) ValueStr(i int) string { - if a.IsNull(i) { - return NullValueStr - } - return a.values[i].FormattedString(a.DataType().(*arrow.Time64Type).Unit) -} - -func (a *Time64) GetOneForMarshal(i int) interface{} { - if a.IsNull(i) { - return nil - } - return a.values[i].ToTime(a.DataType().(*arrow.Time64Type).Unit).Format("15:04:05.999999999") -} - -func (a *Time64) MarshalJSON() ([]byte, error) { - vals := make([]interface{}, a.Len()) - for i := range a.values { - vals[i] = a.GetOneForMarshal(i) - } - - return json.Marshal(vals) -} - -func arrayEqualTime64(left, right *Time64) bool { - for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { - continue - } - if left.Value(i) != right.Value(i) { - return false - } - } - return true -} - -// A type which represents an immutable sequence of arrow.Date32 values. -type Date32 struct { - array - values []arrow.Date32 -} - -// NewDate32Data creates a new Date32. 
-func NewDate32Data(data arrow.ArrayData) *Date32 { - a := &Date32{} - a.refCount = 1 - a.setData(data.(*Data)) - return a -} - -// Reset resets the array for re-use. -func (a *Date32) Reset(data *Data) { - a.setData(data) -} - -// Value returns the value at the specified index. -func (a *Date32) Value(i int) arrow.Date32 { return a.values[i] } - -// Values returns the values. -func (a *Date32) Date32Values() []arrow.Date32 { return a.values } - -// String returns a string representation of the array. -func (a *Date32) String() string { - o := new(strings.Builder) - o.WriteString("[") - for i, v := range a.values { - if i > 0 { - fmt.Fprintf(o, " ") - } - switch { - case a.IsNull(i): - o.WriteString(NullValueStr) - default: - fmt.Fprintf(o, "%v", v) - } - } - o.WriteString("]") - return o.String() -} - -func (a *Date32) setData(data *Data) { - a.array.setData(data) - vals := data.buffers[1] - if vals != nil { - a.values = arrow.Date32Traits.CastFromBytes(vals.Bytes()) - beg := a.array.data.offset - end := beg + a.array.data.length - a.values = a.values[beg:end] - } -} - -func (a *Date32) ValueStr(i int) string { - if a.IsNull(i) { - return NullValueStr - } - return a.values[i].FormattedString() -} - -func (a *Date32) GetOneForMarshal(i int) interface{} { - if a.IsNull(i) { - return nil - } - return a.values[i].ToTime().Format("2006-01-02") -} - -func (a *Date32) MarshalJSON() ([]byte, error) { - vals := make([]interface{}, a.Len()) - for i := range a.values { - vals[i] = a.GetOneForMarshal(i) - } - - return json.Marshal(vals) -} - -func arrayEqualDate32(left, right *Date32) bool { - for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { - continue - } - if left.Value(i) != right.Value(i) { - return false - } - } - return true -} - -// A type which represents an immutable sequence of arrow.Date64 values. -type Date64 struct { - array - values []arrow.Date64 -} - -// NewDate64Data creates a new Date64. -func NewDate64Data(data arrow.ArrayData) *Date64 { - a := &Date64{} - a.refCount = 1 - a.setData(data.(*Data)) - return a -} - -// Reset resets the array for re-use. -func (a *Date64) Reset(data *Data) { - a.setData(data) -} - -// Value returns the value at the specified index. -func (a *Date64) Value(i int) arrow.Date64 { return a.values[i] } - -// Values returns the values. -func (a *Date64) Date64Values() []arrow.Date64 { return a.values } - -// String returns a string representation of the array. 
-func (a *Date64) String() string { - o := new(strings.Builder) - o.WriteString("[") - for i, v := range a.values { - if i > 0 { - fmt.Fprintf(o, " ") - } - switch { - case a.IsNull(i): - o.WriteString(NullValueStr) - default: - fmt.Fprintf(o, "%v", v) - } - } - o.WriteString("]") - return o.String() -} - -func (a *Date64) setData(data *Data) { - a.array.setData(data) - vals := data.buffers[1] - if vals != nil { - a.values = arrow.Date64Traits.CastFromBytes(vals.Bytes()) - beg := a.array.data.offset - end := beg + a.array.data.length - a.values = a.values[beg:end] - } -} - -func (a *Date64) ValueStr(i int) string { - if a.IsNull(i) { - return NullValueStr - } - return a.values[i].FormattedString() -} - -func (a *Date64) GetOneForMarshal(i int) interface{} { - if a.IsNull(i) { - return nil - } - return a.values[i].ToTime().Format("2006-01-02") -} - -func (a *Date64) MarshalJSON() ([]byte, error) { - vals := make([]interface{}, a.Len()) - for i := range a.values { - vals[i] = a.GetOneForMarshal(i) - } - - return json.Marshal(vals) -} - -func arrayEqualDate64(left, right *Date64) bool { - for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { - continue - } - if left.Value(i) != right.Value(i) { - return false - } - } - return true -} - -// A type which represents an immutable sequence of arrow.Duration values. -type Duration struct { - array - values []arrow.Duration -} - -// NewDurationData creates a new Duration. -func NewDurationData(data arrow.ArrayData) *Duration { - a := &Duration{} - a.refCount = 1 - a.setData(data.(*Data)) - return a -} - -// Reset resets the array for re-use. -func (a *Duration) Reset(data *Data) { - a.setData(data) -} - -// Value returns the value at the specified index. -func (a *Duration) Value(i int) arrow.Duration { return a.values[i] } - -// Values returns the values. -func (a *Duration) DurationValues() []arrow.Duration { return a.values } - -// String returns a string representation of the array. 
-func (a *Duration) String() string { - o := new(strings.Builder) - o.WriteString("[") - for i, v := range a.values { - if i > 0 { - fmt.Fprintf(o, " ") - } - switch { - case a.IsNull(i): - o.WriteString(NullValueStr) - default: - fmt.Fprintf(o, "%v", v) - } - } - o.WriteString("]") - return o.String() -} - -func (a *Duration) setData(data *Data) { - a.array.setData(data) - vals := data.buffers[1] - if vals != nil { - a.values = arrow.DurationTraits.CastFromBytes(vals.Bytes()) - beg := a.array.data.offset - end := beg + a.array.data.length - a.values = a.values[beg:end] - } -} - -func (a *Duration) ValueStr(i int) string { - if a.IsNull(i) { - return NullValueStr - } - // return value and suffix as a string such as "12345ms" - return fmt.Sprintf("%d%s", a.values[i], a.DataType().(*arrow.DurationType).Unit) -} - -func (a *Duration) GetOneForMarshal(i int) interface{} { - if a.IsNull(i) { - return nil - } - // return value and suffix as a string such as "12345ms" - return fmt.Sprintf("%d%s", a.values[i], a.DataType().(*arrow.DurationType).Unit.String()) -} - -func (a *Duration) MarshalJSON() ([]byte, error) { - vals := make([]interface{}, a.Len()) - for i := range a.values { - vals[i] = a.GetOneForMarshal(i) - } - - return json.Marshal(vals) -} - -func arrayEqualDuration(left, right *Duration) bool { - for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { - continue - } - if left.Value(i) != right.Value(i) { - return false - } - } - return true -} diff --git a/arrow/array/numeric.gen.go.tmpl b/arrow/array/numeric.gen.go.tmpl deleted file mode 100644 index df07f205..00000000 --- a/arrow/array/numeric.gen.go.tmpl +++ /dev/null @@ -1,192 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package array - -import ( - "fmt" - "strings" - "time" - - "github.com/apache/arrow-go/v18/arrow" - "github.com/apache/arrow-go/v18/internal/json" -) - -{{range .In}} - -// A type which represents an immutable sequence of {{or .QualifiedType .Type}} values. -type {{.Name}} struct { - array - values []{{or .QualifiedType .Type}} -} - -// New{{.Name}}Data creates a new {{.Name}}. -func New{{.Name}}Data(data arrow.ArrayData) *{{.Name}} { - a := &{{.Name}}{} - a.refCount = 1 - a.setData(data.(*Data)) - return a -} - -// Reset resets the array for re-use. -func (a *{{.Name}}) Reset(data *Data) { - a.setData(data) -} - -// Value returns the value at the specified index. -func (a *{{.Name}}) Value(i int) {{or .QualifiedType .Type}} { return a.values[i] } - -// Values returns the values. -func (a *{{.Name}}) {{.Name}}Values() []{{or .QualifiedType .Type}} { return a.values } - -// String returns a string representation of the array. 
-func (a *{{.Name}}) String() string { - o := new(strings.Builder) - o.WriteString("[") - for i, v := range a.values { - if i > 0 { - fmt.Fprintf(o, " ") - } - switch { - case a.IsNull(i): - o.WriteString(NullValueStr) - default: - fmt.Fprintf(o, "%v", v) - } - } - o.WriteString("]") - return o.String() -} - -func (a *{{.Name}}) setData(data *Data) { - a.array.setData(data) - vals := data.buffers[1] - if vals != nil { - a.values = arrow.{{.Name}}Traits.CastFromBytes(vals.Bytes()) - beg := a.array.data.offset - end := beg + a.array.data.length - a.values = a.values[beg:end] - } -} - -func (a *{{.Name}}) ValueStr(i int) string { - if a.IsNull(i) { - return NullValueStr - } -{{if or (eq .Name "Date32") (eq .Name "Date64") -}} - return a.values[i].FormattedString() -{{else if or (eq .Name "Time32") (eq .Name "Time64") -}} - return a.values[i].FormattedString(a.DataType().(*{{.QualifiedType}}Type).Unit) -{{else if (eq .Name "Duration") -}} - // return value and suffix as a string such as "12345ms" - return fmt.Sprintf("%d%s", a.values[i], a.DataType().(*{{.QualifiedType}}Type).Unit) -{{else if or (eq .Name "Int8") (eq .Name "Int16") (eq .Name "Int32") (eq .Name "Int64") -}} - return strconv.FormatInt(int64(a.Value(i)), 10) -{{else if or (eq .Name "Uint8") (eq .Name "Uint16") (eq .Name "Uint32") (eq .Name "Uint64") -}} - return strconv.FormatUint(uint64(a.Value(i)), 10) -{{else if or (eq .Name "Float32") -}} - return strconv.FormatFloat(float64(a.Value(i)), 'g', -1, 32) -{{else if or (eq .Name "Float64") -}} - return strconv.FormatFloat(float64(a.Value(i)), 'g', -1, 64) -{{else}} - return fmt.Sprintf("%v", a.values[i]) -{{end -}} -} - -func (a *{{.Name}}) GetOneForMarshal(i int) interface{} { - if a.IsNull(i) { - return nil - } -{{if or (eq .Name "Date32") (eq .Name "Date64") -}} - return a.values[i].ToTime().Format("2006-01-02") -{{else if or (eq .Name "Time32") (eq .Name "Time64") -}} - return a.values[i].ToTime(a.DataType().(*{{.QualifiedType}}Type).Unit).Format("15:04:05.999999999") -{{else if (eq .Name "Duration") -}} - // return value and suffix as a string such as "12345ms" - return fmt.Sprintf("%d%s", a.values[i], a.DataType().(*{{.QualifiedType}}Type).Unit.String()) -{{else if (eq .Size "1")}} - return float64(a.values[i]) // prevent uint8 from being seen as binary data -{{else}} - return a.values[i] -{{end -}} -} - -func (a *{{.Name}}) MarshalJSON() ([]byte, error) { -{{if .QualifiedType -}} - vals := make([]interface{}, a.Len()) - for i := range a.values { - vals[i] = a.GetOneForMarshal(i) - } -{{else -}} - vals := make([]interface{}, a.Len()) - for i := 0; i < a.Len(); i++ { - {{if (eq .Name "Float32") -}} - if !a.IsValid(i) { - vals[i] = nil - continue - } - - f := a.Value(i) - v := strconv.FormatFloat(float64(f), 'g', -1, 32) - - switch v { - case "NaN", "+Inf", "-Inf": - vals[i] = v - default: - vals[i] = f - } - {{else if (eq .Name "Float64") -}} - if !a.IsValid(i) { - vals[i] = nil - continue - } - - f := a.Value(i) - switch { - case math.IsNaN(f): - vals[i] = "NaN" - case math.IsInf(f, 1): - vals[i] = "+Inf" - case math.IsInf(f, -1): - vals[i] = "-Inf" - default: - vals[i] = f - } - {{else}} - if a.IsValid(i) { - {{ if (eq .Size "1") }}vals[i] = float64(a.values[i]) // prevent uint8 from being seen as binary data{{ else }}vals[i] = a.values[i]{{ end }} - } else { - vals[i] = nil - } - {{end}} - } -{{end}} - return json.Marshal(vals) -} - -func arrayEqual{{.Name}}(left, right *{{.Name}}) bool { - for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { - continue - } - if 
left.Value(i) != right.Value(i) { - return false - } - } - return true -} - -{{end}} diff --git a/arrow/array/numeric_generic.go b/arrow/array/numeric_generic.go new file mode 100644 index 00000000..016dc373 --- /dev/null +++ b/arrow/array/numeric_generic.go @@ -0,0 +1,418 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package array + +import ( + "fmt" + "strconv" + "strings" + "time" + "unsafe" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/internal/json" +) + +type numericArray[T arrow.IntType | arrow.UintType | arrow.FloatType] struct { + array + values []T +} + +func newNumericData[T arrow.IntType | arrow.UintType | arrow.FloatType](data arrow.ArrayData) numericArray[T] { + a := numericArray[T]{} + a.refCount.Add(1) + a.setData(data.(*Data)) + return a +} + +func (a *numericArray[T]) Reset(data *Data) { + a.setData(data) +} + +func (a *numericArray[T]) Value(i int) T { return a.values[i] } +func (a *numericArray[T]) Values() []T { return a.values } +func (a *numericArray[T]) String() string { + o := new(strings.Builder) + o.WriteString("[") + for i, v := range a.values { + if i > 0 { + fmt.Fprintf(o, " ") + } + switch { + case a.IsNull(i): + o.WriteString(NullValueStr) + default: + fmt.Fprintf(o, "%v", v) + } + } + o.WriteString("]") + return o.String() +} + +func (a *numericArray[T]) setData(data *Data) { + a.array.setData(data) + vals := data.buffers[1] + if vals != nil { + a.values = arrow.GetData[T](vals.Bytes()) + beg := a.array.data.offset + end := beg + a.array.data.length + a.values = a.values[beg:end] + } +} + +func (a *numericArray[T]) ValueStr(i int) string { + if a.IsNull(i) { + return NullValueStr + } + + return fmt.Sprintf("%v", a.values[i]) +} + +func (a *numericArray[T]) GetOneForMarshal(i int) any { + if a.IsNull(i) { + return nil + } + + return a.values[i] +} + +func (a *numericArray[T]) MarshalJSON() ([]byte, error) { + vals := make([]any, a.Len()) + for i := range a.Len() { + if a.IsValid(i) { + vals[i] = a.values[i] + } else { + vals[i] = nil + } + } + return json.Marshal(vals) +} + +type oneByteArrs[T int8 | uint8] struct { + numericArray[T] +} + +func (a *oneByteArrs[T]) GetOneForMarshal(i int) any { + if a.IsNull(i) { + return nil + } + + return float64(a.values[i]) // prevent uint8/int8 from being seen as binary data +} + +func (a *oneByteArrs[T]) MarshalJSON() ([]byte, error) { + vals := make([]any, a.Len()) + for i := range a.Len() { + if a.IsValid(i) { + vals[i] = float64(a.values[i]) // prevent uint8/int8 from being seen as binary data + } else { + vals[i] = nil + } + } + return json.Marshal(vals) +} + +type floatArray[T float32 | float64] struct { + numericArray[T] +} + +func (a *floatArray[T]) ValueStr(i int) string { + if a.IsNull(i) { + return NullValueStr + } + + f := 
a.Value(i) + bitWidth := int(unsafe.Sizeof(f) * 8) + return strconv.FormatFloat(float64(a.Value(i)), 'g', -1, bitWidth) +} + +func (a *floatArray[T]) GetOneForMarshal(i int) any { + if a.IsNull(i) { + return nil + } + + f := a.Value(i) + bitWidth := int(unsafe.Sizeof(f) * 8) + v := strconv.FormatFloat(float64(a.Value(i)), 'g', -1, bitWidth) + switch v { + case "NaN", "+Inf", "-Inf": + return v + default: + return f + } +} + +func (a *floatArray[T]) MarshalJSON() ([]byte, error) { + vals := make([]any, a.Len()) + for i := range a.values { + vals[i] = a.GetOneForMarshal(i) + } + return json.Marshal(vals) +} + +type dateArray[T interface { + arrow.Date32 | arrow.Date64 + FormattedString() string + ToTime() time.Time +}] struct { + numericArray[T] +} + +func (d *dateArray[T]) MarshalJSON() ([]byte, error) { + vals := make([]any, d.Len()) + for i := range d.values { + vals[i] = d.GetOneForMarshal(i) + } + return json.Marshal(vals) +} + +func (d *dateArray[T]) ValueStr(i int) string { + if d.IsNull(i) { + return NullValueStr + } + + return d.values[i].FormattedString() +} + +func (d *dateArray[T]) GetOneForMarshal(i int) interface{} { + if d.IsNull(i) { + return nil + } + + return d.values[i].FormattedString() +} + +type timeType interface { + TimeUnit() arrow.TimeUnit +} + +type timeArray[T interface { + arrow.Time32 | arrow.Time64 + FormattedString(arrow.TimeUnit) string + ToTime(arrow.TimeUnit) time.Time +}] struct { + numericArray[T] +} + +func (a *timeArray[T]) MarshalJSON() ([]byte, error) { + vals := make([]any, a.Len()) + for i := range a.values { + vals[i] = a.GetOneForMarshal(i) + } + return json.Marshal(vals) +} + +func (a *timeArray[T]) ValueStr(i int) string { + if a.IsNull(i) { + return NullValueStr + } + + return a.values[i].FormattedString(a.DataType().(timeType).TimeUnit()) +} + +func (a *timeArray[T]) GetOneForMarshal(i int) interface{} { + if a.IsNull(i) { + return nil + } + + return a.values[i].ToTime(a.DataType().(timeType).TimeUnit()).Format("15:04:05.999999999") +} + +type Duration struct { + numericArray[arrow.Duration] +} + +func NewDurationData(data arrow.ArrayData) *Duration { + return &Duration{numericArray: newNumericData[arrow.Duration](data)} +} + +func (a *Duration) DurationValues() []arrow.Duration { return a.Values() } + +func (a *Duration) MarshalJSON() ([]byte, error) { + vals := make([]any, a.Len()) + for i := range a.values { + vals[i] = a.GetOneForMarshal(i) + } + return json.Marshal(vals) +} + +func (a *Duration) ValueStr(i int) string { + if a.IsNull(i) { + return NullValueStr + } + + return fmt.Sprintf("%d%s", a.values[i], a.DataType().(timeType).TimeUnit()) +} + +func (a *Duration) GetOneForMarshal(i int) any { + if a.IsNull(i) { + return nil + } + return fmt.Sprintf("%d%s", a.values[i], a.DataType().(timeType).TimeUnit()) +} + +type Int64 struct { + numericArray[int64] +} + +func NewInt64Data(data arrow.ArrayData) *Int64 { + return &Int64{numericArray: newNumericData[int64](data)} +} + +func (a *Int64) Int64Values() []int64 { return a.Values() } + +type Uint64 struct { + numericArray[uint64] +} + +func NewUint64Data(data arrow.ArrayData) *Uint64 { + return &Uint64{numericArray: newNumericData[uint64](data)} +} + +func (a *Uint64) Uint64Values() []uint64 { return a.Values() } + +type Float32 struct { + floatArray[float32] +} + +func NewFloat32Data(data arrow.ArrayData) *Float32 { + return &Float32{floatArray[float32]{newNumericData[float32](data)}} +} + +func (a *Float32) Float32Values() []float32 { return a.Values() } + +type Float64 struct { + 
floatArray[float64] +} + +func NewFloat64Data(data arrow.ArrayData) *Float64 { + return &Float64{floatArray: floatArray[float64]{newNumericData[float64](data)}} +} + +func (a *Float64) Float64Values() []float64 { return a.Values() } + +type Int32 struct { + numericArray[int32] +} + +func NewInt32Data(data arrow.ArrayData) *Int32 { + return &Int32{newNumericData[int32](data)} +} + +func (a *Int32) Int32Values() []int32 { return a.Values() } + +type Uint32 struct { + numericArray[uint32] +} + +func NewUint32Data(data arrow.ArrayData) *Uint32 { + return &Uint32{numericArray: newNumericData[uint32](data)} +} + +func (a *Uint32) Uint32Values() []uint32 { return a.Values() } + +type Int16 struct { + numericArray[int16] +} + +func NewInt16Data(data arrow.ArrayData) *Int16 { + return &Int16{newNumericData[int16](data)} +} + +func (a *Int16) Int16Values() []int16 { return a.Values() } + +type Uint16 struct { + numericArray[uint16] +} + +func NewUint16Data(data arrow.ArrayData) *Uint16 { + return &Uint16{numericArray: newNumericData[uint16](data)} +} + +func (a *Uint16) Uint16Values() []uint16 { return a.Values() } + +type Int8 struct { + oneByteArrs[int8] +} + +func NewInt8Data(data arrow.ArrayData) *Int8 { + return &Int8{oneByteArrs[int8]{newNumericData[int8](data)}} +} + +func (a *Int8) Int8Values() []int8 { return a.Values() } + +type Uint8 struct { + oneByteArrs[uint8] +} + +func NewUint8Data(data arrow.ArrayData) *Uint8 { + return &Uint8{oneByteArrs[uint8]{newNumericData[uint8](data)}} +} + +func (a *Uint8) Uint8Values() []uint8 { return a.Values() } + +type Time32 struct { + timeArray[arrow.Time32] +} + +func NewTime32Data(data arrow.ArrayData) *Time32 { + return &Time32{timeArray[arrow.Time32]{newNumericData[arrow.Time32](data)}} +} + +func (a *Time32) Time32Values() []arrow.Time32 { return a.Values() } + +type Time64 struct { + timeArray[arrow.Time64] +} + +func NewTime64Data(data arrow.ArrayData) *Time64 { + return &Time64{timeArray[arrow.Time64]{newNumericData[arrow.Time64](data)}} +} + +func (a *Time64) Time64Values() []arrow.Time64 { return a.Values() } + +type Date32 struct { + dateArray[arrow.Date32] +} + +func NewDate32Data(data arrow.ArrayData) *Date32 { + return &Date32{dateArray[arrow.Date32]{newNumericData[arrow.Date32](data)}} +} + +func (a *Date32) Date32Values() []arrow.Date32 { return a.Values() } + +type Date64 struct { + dateArray[arrow.Date64] +} + +func NewDate64Data(data arrow.ArrayData) *Date64 { + return &Date64{dateArray[arrow.Date64]{newNumericData[arrow.Date64](data)}} +} + +func (a *Date64) Date64Values() []arrow.Date64 { return a.Values() } + +func arrayEqualFixedWidth[T arrow.FixedWidthType](left, right arrow.TypedArray[T]) bool { + for i := range left.Len() { + if left.IsNull(i) { + continue + } + if left.Value(i) != right.Value(i) { + return false + } + } + return true +} diff --git a/arrow/array/numericbuilder.gen.go b/arrow/array/numericbuilder.gen.go index 1618dba0..be87fbf4 100644 --- a/arrow/array/numericbuilder.gen.go +++ b/arrow/array/numericbuilder.gen.go @@ -24,7 +24,6 @@ import ( "reflect" "strconv" "strings" - "sync/atomic" "time" "github.com/apache/arrow-go/v18/arrow" @@ -42,7 +41,9 @@ type Int64Builder struct { } func NewInt64Builder(mem memory.Allocator) *Int64Builder { - return &Int64Builder{builder: builder{refCount: 1, mem: mem}} + b := &Int64Builder{builder: builder{mem: mem}} + b.refCount.Add(1) + return b } func (b *Int64Builder) Type() arrow.DataType { return arrow.PrimitiveTypes.Int64 } @@ -50,9 +51,9 @@ func (b *Int64Builder) Type() 
arrow.DataType { return arrow.PrimitiveTypes.Int64 // Release decreases the reference count by 1. // When the reference count goes to zero, the memory is freed. func (b *Int64Builder) Release() { - debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases") + debug.Assert(b.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&b.refCount, -1) == 0 { + if b.refCount.Add(-1) == 0 { if b.nullBitmap != nil { b.nullBitmap.Release() b.nullBitmap = nil @@ -281,7 +282,9 @@ type Uint64Builder struct { } func NewUint64Builder(mem memory.Allocator) *Uint64Builder { - return &Uint64Builder{builder: builder{refCount: 1, mem: mem}} + b := &Uint64Builder{builder: builder{mem: mem}} + b.refCount.Add(1) + return b } func (b *Uint64Builder) Type() arrow.DataType { return arrow.PrimitiveTypes.Uint64 } @@ -289,9 +292,9 @@ func (b *Uint64Builder) Type() arrow.DataType { return arrow.PrimitiveTypes.Uint // Release decreases the reference count by 1. // When the reference count goes to zero, the memory is freed. func (b *Uint64Builder) Release() { - debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases") + debug.Assert(b.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&b.refCount, -1) == 0 { + if b.refCount.Add(-1) == 0 { if b.nullBitmap != nil { b.nullBitmap.Release() b.nullBitmap = nil @@ -520,7 +523,9 @@ type Float64Builder struct { } func NewFloat64Builder(mem memory.Allocator) *Float64Builder { - return &Float64Builder{builder: builder{refCount: 1, mem: mem}} + b := &Float64Builder{builder: builder{mem: mem}} + b.refCount.Add(1) + return b } func (b *Float64Builder) Type() arrow.DataType { return arrow.PrimitiveTypes.Float64 } @@ -528,9 +533,9 @@ func (b *Float64Builder) Type() arrow.DataType { return arrow.PrimitiveTypes.Flo // Release decreases the reference count by 1. // When the reference count goes to zero, the memory is freed. func (b *Float64Builder) Release() { - debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases") + debug.Assert(b.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&b.refCount, -1) == 0 { + if b.refCount.Add(-1) == 0 { if b.nullBitmap != nil { b.nullBitmap.Release() b.nullBitmap = nil @@ -759,7 +764,9 @@ type Int32Builder struct { } func NewInt32Builder(mem memory.Allocator) *Int32Builder { - return &Int32Builder{builder: builder{refCount: 1, mem: mem}} + b := &Int32Builder{builder: builder{mem: mem}} + b.refCount.Add(1) + return b } func (b *Int32Builder) Type() arrow.DataType { return arrow.PrimitiveTypes.Int32 } @@ -767,9 +774,9 @@ func (b *Int32Builder) Type() arrow.DataType { return arrow.PrimitiveTypes.Int32 // Release decreases the reference count by 1. // When the reference count goes to zero, the memory is freed. 
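For context on the generic machinery introduced in numeric_generic.go above, here is a minimal, illustrative sketch (not part of the patch) of writing code against arrow.TypedArray[T], the same constraint pattern arrayEqualFixedWidth uses; the countValue helper and the sample values are invented for illustration:

package main

import (
	"fmt"

	"github.com/apache/arrow-go/v18/arrow"
	"github.com/apache/arrow-go/v18/arrow/array"
	"github.com/apache/arrow-go/v18/arrow/memory"
)

// countValue walks any fixed-width typed array generically, skipping nulls.
func countValue[T arrow.FixedWidthType](arr arrow.TypedArray[T], want T) int {
	n := 0
	for i := 0; i < arr.Len(); i++ {
		if !arr.IsNull(i) && arr.Value(i) == want {
			n++
		}
	}
	return n
}

func main() {
	mem := memory.NewGoAllocator()

	b := array.NewInt64Builder(mem)
	defer b.Release()
	b.AppendValues([]int64{1, 2, 2, 3}, nil)

	arr := b.NewInt64Array()
	defer arr.Release()

	// explicit instantiation is needed since T cannot be inferred from the interface
	fmt.Println(countValue[int64](arr, 2)) // prints 2
}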
func (b *Int32Builder) Release() { - debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases") + debug.Assert(b.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&b.refCount, -1) == 0 { + if b.refCount.Add(-1) == 0 { if b.nullBitmap != nil { b.nullBitmap.Release() b.nullBitmap = nil @@ -998,7 +1005,9 @@ type Uint32Builder struct { } func NewUint32Builder(mem memory.Allocator) *Uint32Builder { - return &Uint32Builder{builder: builder{refCount: 1, mem: mem}} + b := &Uint32Builder{builder: builder{mem: mem}} + b.refCount.Add(1) + return b } func (b *Uint32Builder) Type() arrow.DataType { return arrow.PrimitiveTypes.Uint32 } @@ -1006,9 +1015,9 @@ func (b *Uint32Builder) Type() arrow.DataType { return arrow.PrimitiveTypes.Uint // Release decreases the reference count by 1. // When the reference count goes to zero, the memory is freed. func (b *Uint32Builder) Release() { - debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases") + debug.Assert(b.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&b.refCount, -1) == 0 { + if b.refCount.Add(-1) == 0 { if b.nullBitmap != nil { b.nullBitmap.Release() b.nullBitmap = nil @@ -1237,7 +1246,9 @@ type Float32Builder struct { } func NewFloat32Builder(mem memory.Allocator) *Float32Builder { - return &Float32Builder{builder: builder{refCount: 1, mem: mem}} + b := &Float32Builder{builder: builder{mem: mem}} + b.refCount.Add(1) + return b } func (b *Float32Builder) Type() arrow.DataType { return arrow.PrimitiveTypes.Float32 } @@ -1245,9 +1256,9 @@ func (b *Float32Builder) Type() arrow.DataType { return arrow.PrimitiveTypes.Flo // Release decreases the reference count by 1. // When the reference count goes to zero, the memory is freed. func (b *Float32Builder) Release() { - debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases") + debug.Assert(b.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&b.refCount, -1) == 0 { + if b.refCount.Add(-1) == 0 { if b.nullBitmap != nil { b.nullBitmap.Release() b.nullBitmap = nil @@ -1476,7 +1487,9 @@ type Int16Builder struct { } func NewInt16Builder(mem memory.Allocator) *Int16Builder { - return &Int16Builder{builder: builder{refCount: 1, mem: mem}} + b := &Int16Builder{builder: builder{mem: mem}} + b.refCount.Add(1) + return b } func (b *Int16Builder) Type() arrow.DataType { return arrow.PrimitiveTypes.Int16 } @@ -1484,9 +1497,9 @@ func (b *Int16Builder) Type() arrow.DataType { return arrow.PrimitiveTypes.Int16 // Release decreases the reference count by 1. // When the reference count goes to zero, the memory is freed. func (b *Int16Builder) Release() { - debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases") + debug.Assert(b.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&b.refCount, -1) == 0 { + if b.refCount.Add(-1) == 0 { if b.nullBitmap != nil { b.nullBitmap.Release() b.nullBitmap = nil @@ -1715,7 +1728,9 @@ type Uint16Builder struct { } func NewUint16Builder(mem memory.Allocator) *Uint16Builder { - return &Uint16Builder{builder: builder{refCount: 1, mem: mem}} + b := &Uint16Builder{builder: builder{mem: mem}} + b.refCount.Add(1) + return b } func (b *Uint16Builder) Type() arrow.DataType { return arrow.PrimitiveTypes.Uint16 } @@ -1723,9 +1738,9 @@ func (b *Uint16Builder) Type() arrow.DataType { return arrow.PrimitiveTypes.Uint // Release decreases the reference count by 1. // When the reference count goes to zero, the memory is freed. 
func (b *Uint16Builder) Release() { - debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases") + debug.Assert(b.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&b.refCount, -1) == 0 { + if b.refCount.Add(-1) == 0 { if b.nullBitmap != nil { b.nullBitmap.Release() b.nullBitmap = nil @@ -1954,7 +1969,9 @@ type Int8Builder struct { } func NewInt8Builder(mem memory.Allocator) *Int8Builder { - return &Int8Builder{builder: builder{refCount: 1, mem: mem}} + b := &Int8Builder{builder: builder{mem: mem}} + b.refCount.Add(1) + return b } func (b *Int8Builder) Type() arrow.DataType { return arrow.PrimitiveTypes.Int8 } @@ -1962,9 +1979,9 @@ func (b *Int8Builder) Type() arrow.DataType { return arrow.PrimitiveTypes.Int8 } // Release decreases the reference count by 1. // When the reference count goes to zero, the memory is freed. func (b *Int8Builder) Release() { - debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases") + debug.Assert(b.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&b.refCount, -1) == 0 { + if b.refCount.Add(-1) == 0 { if b.nullBitmap != nil { b.nullBitmap.Release() b.nullBitmap = nil @@ -2193,7 +2210,9 @@ type Uint8Builder struct { } func NewUint8Builder(mem memory.Allocator) *Uint8Builder { - return &Uint8Builder{builder: builder{refCount: 1, mem: mem}} + b := &Uint8Builder{builder: builder{mem: mem}} + b.refCount.Add(1) + return b } func (b *Uint8Builder) Type() arrow.DataType { return arrow.PrimitiveTypes.Uint8 } @@ -2201,9 +2220,9 @@ func (b *Uint8Builder) Type() arrow.DataType { return arrow.PrimitiveTypes.Uint8 // Release decreases the reference count by 1. // When the reference count goes to zero, the memory is freed. func (b *Uint8Builder) Release() { - debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases") + debug.Assert(b.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&b.refCount, -1) == 0 { + if b.refCount.Add(-1) == 0 { if b.nullBitmap != nil { b.nullBitmap.Release() b.nullBitmap = nil @@ -2433,7 +2452,9 @@ type Time32Builder struct { } func NewTime32Builder(mem memory.Allocator, dtype *arrow.Time32Type) *Time32Builder { - return &Time32Builder{builder: builder{refCount: 1, mem: mem}, dtype: dtype} + b := &Time32Builder{builder: builder{mem: mem}, dtype: dtype} + b.refCount.Add(1) + return b } func (b *Time32Builder) Type() arrow.DataType { return b.dtype } @@ -2441,9 +2462,9 @@ func (b *Time32Builder) Type() arrow.DataType { return b.dtype } // Release decreases the reference count by 1. // When the reference count goes to zero, the memory is freed. func (b *Time32Builder) Release() { - debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases") + debug.Assert(b.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&b.refCount, -1) == 0 { + if b.refCount.Add(-1) == 0 { if b.nullBitmap != nil { b.nullBitmap.Release() b.nullBitmap = nil @@ -2673,7 +2694,9 @@ type Time64Builder struct { } func NewTime64Builder(mem memory.Allocator, dtype *arrow.Time64Type) *Time64Builder { - return &Time64Builder{builder: builder{refCount: 1, mem: mem}, dtype: dtype} + b := &Time64Builder{builder: builder{mem: mem}, dtype: dtype} + b.refCount.Add(1) + return b } func (b *Time64Builder) Type() arrow.DataType { return b.dtype } @@ -2681,9 +2704,9 @@ func (b *Time64Builder) Type() arrow.DataType { return b.dtype } // Release decreases the reference count by 1. // When the reference count goes to zero, the memory is freed. 
func (b *Time64Builder) Release() { - debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases") + debug.Assert(b.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&b.refCount, -1) == 0 { + if b.refCount.Add(-1) == 0 { if b.nullBitmap != nil { b.nullBitmap.Release() b.nullBitmap = nil @@ -2912,7 +2935,9 @@ type Date32Builder struct { } func NewDate32Builder(mem memory.Allocator) *Date32Builder { - return &Date32Builder{builder: builder{refCount: 1, mem: mem}} + b := &Date32Builder{builder: builder{mem: mem}} + b.refCount.Add(1) + return b } func (b *Date32Builder) Type() arrow.DataType { return arrow.PrimitiveTypes.Date32 } @@ -2920,9 +2945,9 @@ func (b *Date32Builder) Type() arrow.DataType { return arrow.PrimitiveTypes.Date // Release decreases the reference count by 1. // When the reference count goes to zero, the memory is freed. func (b *Date32Builder) Release() { - debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases") + debug.Assert(b.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&b.refCount, -1) == 0 { + if b.refCount.Add(-1) == 0 { if b.nullBitmap != nil { b.nullBitmap.Release() b.nullBitmap = nil @@ -3151,7 +3176,9 @@ type Date64Builder struct { } func NewDate64Builder(mem memory.Allocator) *Date64Builder { - return &Date64Builder{builder: builder{refCount: 1, mem: mem}} + b := &Date64Builder{builder: builder{mem: mem}} + b.refCount.Add(1) + return b } func (b *Date64Builder) Type() arrow.DataType { return arrow.PrimitiveTypes.Date64 } @@ -3159,9 +3186,9 @@ func (b *Date64Builder) Type() arrow.DataType { return arrow.PrimitiveTypes.Date // Release decreases the reference count by 1. // When the reference count goes to zero, the memory is freed. func (b *Date64Builder) Release() { - debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases") + debug.Assert(b.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&b.refCount, -1) == 0 { + if b.refCount.Add(-1) == 0 { if b.nullBitmap != nil { b.nullBitmap.Release() b.nullBitmap = nil @@ -3391,7 +3418,9 @@ type DurationBuilder struct { } func NewDurationBuilder(mem memory.Allocator, dtype *arrow.DurationType) *DurationBuilder { - return &DurationBuilder{builder: builder{refCount: 1, mem: mem}, dtype: dtype} + b := &DurationBuilder{builder: builder{mem: mem}, dtype: dtype} + b.refCount.Add(1) + return b } func (b *DurationBuilder) Type() arrow.DataType { return b.dtype } @@ -3399,9 +3428,9 @@ func (b *DurationBuilder) Type() arrow.DataType { return b.dtype } // Release decreases the reference count by 1. // When the reference count goes to zero, the memory is freed. 
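The builder constructors above switch from a literal refCount: 1 field to an embedded atomic.Int64 that is bumped after construction (an atomic.Int64 cannot be set in a composite literal). A standalone sketch of that pattern, with hypothetical names, assuming nothing beyond the standard library:

package main

import (
	"fmt"
	"sync/atomic"
)

type refCounted struct {
	refCount atomic.Int64 // zero value starts at 0
}

func newRefCounted() *refCounted {
	r := &refCounted{}
	r.refCount.Add(1) // start with a single owner, mirroring the New*Builder constructors
	return r
}

func (r *refCounted) Retain() { r.refCount.Add(1) }

func (r *refCounted) Release() {
	if r.refCount.Add(-1) == 0 {
		fmt.Println("freeing underlying buffers")
	}
}

func main() {
	r := newRefCounted()
	r.Retain()
	r.Release()
	r.Release() // prints: freeing underlying buffers
}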
func (b *DurationBuilder) Release() { - debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases") + debug.Assert(b.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&b.refCount, -1) == 0 { + if b.refCount.Add(-1) == 0 { if b.nullBitmap != nil { b.nullBitmap.Release() b.nullBitmap = nil diff --git a/arrow/array/numericbuilder.gen.go.tmpl b/arrow/array/numericbuilder.gen.go.tmpl index e84e095c..518b3d4c 100644 --- a/arrow/array/numericbuilder.gen.go.tmpl +++ b/arrow/array/numericbuilder.gen.go.tmpl @@ -38,14 +38,18 @@ type {{.Name}}Builder struct { {{if .Opt.Parametric}} func New{{.Name}}Builder(mem memory.Allocator, dtype *arrow.{{.Name}}Type) *{{.Name}}Builder { - return &{{.Name}}Builder{builder: builder{refCount:1, mem: mem}, dtype: dtype} + b := &{{.Name}}Builder{builder: builder{mem: mem}, dtype: dtype} + b.refCount.Add(1) + return b } func (b *{{.Name}}Builder) Type() arrow.DataType { return b.dtype } {{else}} func New{{.Name}}Builder(mem memory.Allocator) *{{.Name}}Builder { - return &{{.Name}}Builder{builder: builder{refCount:1, mem: mem}} + b := &{{.Name}}Builder{builder: builder{mem: mem}} + b.refCount.Add(1) + return b } func (b *{{.Name}}Builder) Type() arrow.DataType { return arrow.PrimitiveTypes.{{.Name}} } @@ -54,9 +58,9 @@ func (b *{{.Name}}Builder) Type() arrow.DataType { return arrow.PrimitiveTypes.{ // Release decreases the reference count by 1. // When the reference count goes to zero, the memory is freed. func (b *{{.Name}}Builder) Release() { - debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases") + debug.Assert(b.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&b.refCount, -1) == 0 { + if b.refCount.Add(-1) == 0 { if b.nullBitmap != nil { b.nullBitmap.Release() b.nullBitmap = nil diff --git a/arrow/array/numericbuilder.gen_test.go b/arrow/array/numericbuilder.gen_test.go index 677a5dd5..1336815b 100644 --- a/arrow/array/numericbuilder.gen_test.go +++ b/arrow/array/numericbuilder.gen_test.go @@ -230,6 +230,30 @@ func TestInt64Builder_Resize(t *testing.T) { assert.Equal(t, 5, ab.Len()) } +func TestInt64BuilderUnmarshalJSON(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) + defer mem.AssertSize(t, 0) + + bldr := array.NewInt64Builder(mem) + defer bldr.Release() + + jsonstr := `[0, 1, null, 2.3, -11]` + + err := bldr.UnmarshalJSON([]byte(jsonstr)) + assert.NoError(t, err) + + arr := bldr.NewInt64Array() + defer arr.Release() + + assert.NotNil(t, arr) + + assert.Equal(t, int64(0), int64(arr.Value(0))) + assert.Equal(t, int64(1), int64(arr.Value(1))) + assert.True(t, arr.IsNull(2)) + assert.Equal(t, int64(2), int64(arr.Value(3))) + assert.Equal(t, int64(5), int64(arr.Len())) +} + func TestUint64StringRoundTrip(t *testing.T) { // 1. 
create array mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) @@ -432,6 +456,30 @@ func TestUint64Builder_Resize(t *testing.T) { assert.Equal(t, 5, ab.Len()) } +func TestUint64BuilderUnmarshalJSON(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) + defer mem.AssertSize(t, 0) + + bldr := array.NewUint64Builder(mem) + defer bldr.Release() + + jsonstr := `[0, 1, null, 2.3, -11]` + + err := bldr.UnmarshalJSON([]byte(jsonstr)) + assert.NoError(t, err) + + arr := bldr.NewUint64Array() + defer arr.Release() + + assert.NotNil(t, arr) + + assert.Equal(t, int64(0), int64(arr.Value(0))) + assert.Equal(t, int64(1), int64(arr.Value(1))) + assert.True(t, arr.IsNull(2)) + assert.Equal(t, int64(2), int64(arr.Value(3))) + assert.Equal(t, int64(5), int64(arr.Len())) +} + func TestFloat64StringRoundTrip(t *testing.T) { // 1. create array mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) @@ -858,6 +906,30 @@ func TestInt32Builder_Resize(t *testing.T) { assert.Equal(t, 5, ab.Len()) } +func TestInt32BuilderUnmarshalJSON(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) + defer mem.AssertSize(t, 0) + + bldr := array.NewInt32Builder(mem) + defer bldr.Release() + + jsonstr := `[0, 1, null, 2.3, -11]` + + err := bldr.UnmarshalJSON([]byte(jsonstr)) + assert.NoError(t, err) + + arr := bldr.NewInt32Array() + defer arr.Release() + + assert.NotNil(t, arr) + + assert.Equal(t, int64(0), int64(arr.Value(0))) + assert.Equal(t, int64(1), int64(arr.Value(1))) + assert.True(t, arr.IsNull(2)) + assert.Equal(t, int64(2), int64(arr.Value(3))) + assert.Equal(t, int64(5), int64(arr.Len())) +} + func TestUint32StringRoundTrip(t *testing.T) { // 1. create array mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) @@ -1060,6 +1132,30 @@ func TestUint32Builder_Resize(t *testing.T) { assert.Equal(t, 5, ab.Len()) } +func TestUint32BuilderUnmarshalJSON(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) + defer mem.AssertSize(t, 0) + + bldr := array.NewUint32Builder(mem) + defer bldr.Release() + + jsonstr := `[0, 1, null, 2.3, -11]` + + err := bldr.UnmarshalJSON([]byte(jsonstr)) + assert.NoError(t, err) + + arr := bldr.NewUint32Array() + defer arr.Release() + + assert.NotNil(t, arr) + + assert.Equal(t, int64(0), int64(arr.Value(0))) + assert.Equal(t, int64(1), int64(arr.Value(1))) + assert.True(t, arr.IsNull(2)) + assert.Equal(t, int64(2), int64(arr.Value(3))) + assert.Equal(t, int64(5), int64(arr.Len())) +} + func TestFloat32StringRoundTrip(t *testing.T) { // 1. create array mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) @@ -1486,6 +1582,30 @@ func TestInt16Builder_Resize(t *testing.T) { assert.Equal(t, 5, ab.Len()) } +func TestInt16BuilderUnmarshalJSON(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) + defer mem.AssertSize(t, 0) + + bldr := array.NewInt16Builder(mem) + defer bldr.Release() + + jsonstr := `[0, 1, null, 2.3, -11]` + + err := bldr.UnmarshalJSON([]byte(jsonstr)) + assert.NoError(t, err) + + arr := bldr.NewInt16Array() + defer arr.Release() + + assert.NotNil(t, arr) + + assert.Equal(t, int64(0), int64(arr.Value(0))) + assert.Equal(t, int64(1), int64(arr.Value(1))) + assert.True(t, arr.IsNull(2)) + assert.Equal(t, int64(2), int64(arr.Value(3))) + assert.Equal(t, int64(5), int64(arr.Len())) +} + func TestUint16StringRoundTrip(t *testing.T) { // 1. 
create array mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) @@ -1688,6 +1808,30 @@ func TestUint16Builder_Resize(t *testing.T) { assert.Equal(t, 5, ab.Len()) } +func TestUint16BuilderUnmarshalJSON(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) + defer mem.AssertSize(t, 0) + + bldr := array.NewUint16Builder(mem) + defer bldr.Release() + + jsonstr := `[0, 1, null, 2.3, -11]` + + err := bldr.UnmarshalJSON([]byte(jsonstr)) + assert.NoError(t, err) + + arr := bldr.NewUint16Array() + defer arr.Release() + + assert.NotNil(t, arr) + + assert.Equal(t, int64(0), int64(arr.Value(0))) + assert.Equal(t, int64(1), int64(arr.Value(1))) + assert.True(t, arr.IsNull(2)) + assert.Equal(t, int64(2), int64(arr.Value(3))) + assert.Equal(t, int64(5), int64(arr.Len())) +} + func TestInt8StringRoundTrip(t *testing.T) { // 1. create array mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) @@ -1890,6 +2034,30 @@ func TestInt8Builder_Resize(t *testing.T) { assert.Equal(t, 5, ab.Len()) } +func TestInt8BuilderUnmarshalJSON(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) + defer mem.AssertSize(t, 0) + + bldr := array.NewInt8Builder(mem) + defer bldr.Release() + + jsonstr := `[0, 1, null, 2.3, -11]` + + err := bldr.UnmarshalJSON([]byte(jsonstr)) + assert.NoError(t, err) + + arr := bldr.NewInt8Array() + defer arr.Release() + + assert.NotNil(t, arr) + + assert.Equal(t, int64(0), int64(arr.Value(0))) + assert.Equal(t, int64(1), int64(arr.Value(1))) + assert.True(t, arr.IsNull(2)) + assert.Equal(t, int64(2), int64(arr.Value(3))) + assert.Equal(t, int64(5), int64(arr.Len())) +} + func TestUint8StringRoundTrip(t *testing.T) { // 1. create array mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) @@ -2092,6 +2260,30 @@ func TestUint8Builder_Resize(t *testing.T) { assert.Equal(t, 5, ab.Len()) } +func TestUint8BuilderUnmarshalJSON(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) + defer mem.AssertSize(t, 0) + + bldr := array.NewUint8Builder(mem) + defer bldr.Release() + + jsonstr := `[0, 1, null, 2.3, -11]` + + err := bldr.UnmarshalJSON([]byte(jsonstr)) + assert.NoError(t, err) + + arr := bldr.NewUint8Array() + defer arr.Release() + + assert.NotNil(t, arr) + + assert.Equal(t, int64(0), int64(arr.Value(0))) + assert.Equal(t, int64(1), int64(arr.Value(1))) + assert.True(t, arr.IsNull(2)) + assert.Equal(t, int64(2), int64(arr.Value(3))) + assert.Equal(t, int64(5), int64(arr.Len())) +} + func TestTime32StringRoundTrip(t *testing.T) { // 1. create array mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) @@ -2299,6 +2491,31 @@ func TestTime32Builder_Resize(t *testing.T) { assert.Equal(t, 5, ab.Len()) } +func TestTime32BuilderUnmarshalJSON(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) + defer mem.AssertSize(t, 0) + + dtype := &arrow.Time32Type{Unit: arrow.Second} + bldr := array.NewTime32Builder(mem, dtype) + defer bldr.Release() + + jsonstr := `[0, 1, null, 2.3, -11]` + + err := bldr.UnmarshalJSON([]byte(jsonstr)) + assert.NoError(t, err) + + arr := bldr.NewTime32Array() + defer arr.Release() + + assert.NotNil(t, arr) + + assert.Equal(t, int64(0), int64(arr.Value(0))) + assert.Equal(t, int64(1), int64(arr.Value(1))) + assert.True(t, arr.IsNull(2)) + assert.Equal(t, int64(2), int64(arr.Value(3))) + assert.Equal(t, int64(5), int64(arr.Len())) +} + func TestTime64StringRoundTrip(t *testing.T) { // 1. 
create array mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) @@ -2506,6 +2723,31 @@ func TestTime64Builder_Resize(t *testing.T) { assert.Equal(t, 5, ab.Len()) } +func TestTime64BuilderUnmarshalJSON(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) + defer mem.AssertSize(t, 0) + + dtype := &arrow.Time64Type{Unit: arrow.Second} + bldr := array.NewTime64Builder(mem, dtype) + defer bldr.Release() + + jsonstr := `[0, 1, null, 2.3, -11]` + + err := bldr.UnmarshalJSON([]byte(jsonstr)) + assert.NoError(t, err) + + arr := bldr.NewTime64Array() + defer arr.Release() + + assert.NotNil(t, arr) + + assert.Equal(t, int64(0), int64(arr.Value(0))) + assert.Equal(t, int64(1), int64(arr.Value(1))) + assert.True(t, arr.IsNull(2)) + assert.Equal(t, int64(2), int64(arr.Value(3))) + assert.Equal(t, int64(5), int64(arr.Len())) +} + func TestDate32StringRoundTrip(t *testing.T) { // 1. create array mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) @@ -2708,6 +2950,30 @@ func TestDate32Builder_Resize(t *testing.T) { assert.Equal(t, 5, ab.Len()) } +func TestDate32BuilderUnmarshalJSON(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) + defer mem.AssertSize(t, 0) + + bldr := array.NewDate32Builder(mem) + defer bldr.Release() + + jsonstr := `[0, 1, null, 2.3, -11]` + + err := bldr.UnmarshalJSON([]byte(jsonstr)) + assert.NoError(t, err) + + arr := bldr.NewDate32Array() + defer arr.Release() + + assert.NotNil(t, arr) + + assert.Equal(t, int64(0), int64(arr.Value(0))) + assert.Equal(t, int64(1), int64(arr.Value(1))) + assert.True(t, arr.IsNull(2)) + assert.Equal(t, int64(2), int64(arr.Value(3))) + assert.Equal(t, int64(5), int64(arr.Len())) +} + func TestDate64StringRoundTrip(t *testing.T) { // 1. create array mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) @@ -2917,6 +3183,30 @@ func TestDate64Builder_Resize(t *testing.T) { assert.Equal(t, 5, ab.Len()) } +func TestDate64BuilderUnmarshalJSON(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) + defer mem.AssertSize(t, 0) + + bldr := array.NewDate64Builder(mem) + defer bldr.Release() + + jsonstr := `[0, 1, null, 2.3, -11]` + + err := bldr.UnmarshalJSON([]byte(jsonstr)) + assert.NoError(t, err) + + arr := bldr.NewDate64Array() + defer arr.Release() + + assert.NotNil(t, arr) + + assert.Equal(t, int64(0), int64(arr.Value(0))) + assert.Equal(t, int64(1), int64(arr.Value(1))) + assert.True(t, arr.IsNull(2)) + assert.Equal(t, int64(2), int64(arr.Value(3))) + assert.Equal(t, int64(5), int64(arr.Len())) +} + func TestDurationStringRoundTrip(t *testing.T) { // 1. 
create array mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) @@ -3123,3 +3413,28 @@ func TestDurationBuilder_Resize(t *testing.T) { ab.Resize(32) assert.Equal(t, 5, ab.Len()) } + +func TestDurationBuilderUnmarshalJSON(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) + defer mem.AssertSize(t, 0) + + dtype := &arrow.DurationType{Unit: arrow.Second} + bldr := array.NewDurationBuilder(mem, dtype) + defer bldr.Release() + + jsonstr := `[0, 1, null, 2.3, -11]` + + err := bldr.UnmarshalJSON([]byte(jsonstr)) + assert.NoError(t, err) + + arr := bldr.NewDurationArray() + defer arr.Release() + + assert.NotNil(t, arr) + + assert.Equal(t, int64(0), int64(arr.Value(0))) + assert.Equal(t, int64(1), int64(arr.Value(1))) + assert.True(t, arr.IsNull(2)) + assert.Equal(t, int64(2), int64(arr.Value(3))) + assert.Equal(t, int64(5), int64(arr.Len())) +} diff --git a/arrow/array/numericbuilder.gen_test.go.tmpl b/arrow/array/numericbuilder.gen_test.go.tmpl index a5d58f48..86cc74a5 100644 --- a/arrow/array/numericbuilder.gen_test.go.tmpl +++ b/arrow/array/numericbuilder.gen_test.go.tmpl @@ -276,9 +276,17 @@ func Test{{.Name}}BuilderUnmarshalJSON(t *testing.T) { mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) defer mem.AssertSize(t, 0) +{{if .Opt.Parametric -}} + dtype := &arrow.{{.Name}}Type{Unit: arrow.Second} + bldr := array.New{{.Name}}Builder(mem, dtype) +{{else}} bldr := array.New{{.Name}}Builder(mem) +{{end -}} + defer bldr.Release() + +{{ if or (eq .Name "Float64") (eq .Name "Float32") -}} jsonstr := `[0, 1, "+Inf", 2, 3, "NaN", "NaN", 4, 5, "-Inf"]` err := bldr.UnmarshalJSON([]byte(jsonstr)) @@ -292,6 +300,23 @@ func Test{{.Name}}BuilderUnmarshalJSON(t *testing.T) { assert.False(t, math.IsInf(float64(arr.Value(0)), 0), arr.Value(0)) assert.True(t, math.IsInf(float64(arr.Value(2)), 1), arr.Value(2)) assert.True(t, math.IsNaN(float64(arr.Value(5))), arr.Value(5)) +{{else}} + jsonstr := `[0, 1, null, 2.3, -11]` + + err := bldr.UnmarshalJSON([]byte(jsonstr)) + assert.NoError(t, err) + + arr := bldr.New{{.Name}}Array() + defer arr.Release() + + assert.NotNil(t, arr) + + assert.Equal(t, int64(0), int64(arr.Value(0))) + assert.Equal(t, int64(1), int64(arr.Value(1))) + assert.True(t, arr.IsNull(2)) + assert.Equal(t, int64(2), int64(arr.Value(3))) + assert.Equal(t, int64(5), int64(arr.Len())) +{{end -}} } {{end}} diff --git a/arrow/array/record.go b/arrow/array/record.go index b8041e27..18a50ed0 100644 --- a/arrow/array/record.go +++ b/arrow/array/record.go @@ -19,6 +19,7 @@ package array import ( "bytes" "fmt" + "iter" "strings" "sync/atomic" @@ -42,7 +43,7 @@ type RecordReader interface { // simpleRecords is a simple iterator over a collection of records. type simpleRecords struct { - refCount int64 + refCount atomic.Int64 schema *arrow.Schema recs []arrow.Record @@ -52,11 +53,11 @@ type simpleRecords struct { // NewRecordReader returns a simple iterator over the given slice of records. func NewRecordReader(schema *arrow.Schema, recs []arrow.Record) (RecordReader, error) { rs := &simpleRecords{ - refCount: 1, - schema: schema, - recs: recs, - cur: nil, + schema: schema, + recs: recs, + cur: nil, } + rs.refCount.Add(1) for _, rec := range rs.recs { rec.Retain() @@ -75,16 +76,16 @@ func NewRecordReader(schema *arrow.Schema, recs []arrow.Record) (RecordReader, e // Retain increases the reference count by 1. // Retain may be called simultaneously from multiple goroutines. 
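The generated *BuilderUnmarshalJSON tests above all feed the same JSON shape into a builder. A minimal usage sketch mirroring those test expectations (the printed values follow directly from the assertions in the tests; everything else is illustrative):

package main

import (
	"fmt"

	"github.com/apache/arrow-go/v18/arrow/array"
	"github.com/apache/arrow-go/v18/arrow/memory"
)

func main() {
	mem := memory.NewGoAllocator()

	bldr := array.NewInt64Builder(mem)
	defer bldr.Release()

	// same input shape the generated tests use: null stays null,
	// and the fractional 2.3 is decoded as the integral value 2
	if err := bldr.UnmarshalJSON([]byte(`[0, 1, null, 2.3, -11]`)); err != nil {
		panic(err)
	}

	arr := bldr.NewInt64Array()
	defer arr.Release()

	fmt.Println(arr.Len(), arr.IsNull(2), arr.Value(3)) // 5 true 2
}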
func (rs *simpleRecords) Retain() { - atomic.AddInt64(&rs.refCount, 1) + rs.refCount.Add(1) } // Release decreases the reference count by 1. // When the reference count goes to zero, the memory is freed. // Release may be called simultaneously from multiple goroutines. func (rs *simpleRecords) Release() { - debug.Assert(atomic.LoadInt64(&rs.refCount) > 0, "too many releases") + debug.Assert(rs.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&rs.refCount, -1) == 0 { + if rs.refCount.Add(-1) == 0 { if rs.cur != nil { rs.cur.Release() } @@ -112,7 +113,7 @@ func (rs *simpleRecords) Err() error { return nil } // simpleRecord is a basic, non-lazy in-memory record batch. type simpleRecord struct { - refCount int64 + refCount atomic.Int64 schema *arrow.Schema @@ -126,11 +127,12 @@ type simpleRecord struct { // NewRecord panics if rows is larger than the height of the columns. func NewRecord(schema *arrow.Schema, cols []arrow.Array, nrows int64) arrow.Record { rec := &simpleRecord{ - refCount: 1, - schema: schema, - rows: nrows, - arrs: make([]arrow.Array, len(cols)), + schema: schema, + rows: nrows, + arrs: make([]arrow.Array, len(cols)), } + rec.refCount.Add(1) + copy(rec.arrs, cols) for _, arr := range rec.arrs { arr.Retain() @@ -210,16 +212,16 @@ func (rec *simpleRecord) validate() error { // Retain increases the reference count by 1. // Retain may be called simultaneously from multiple goroutines. func (rec *simpleRecord) Retain() { - atomic.AddInt64(&rec.refCount, 1) + rec.refCount.Add(1) } // Release decreases the reference count by 1. // When the reference count goes to zero, the memory is freed. // Release may be called simultaneously from multiple goroutines. func (rec *simpleRecord) Release() { - debug.Assert(atomic.LoadInt64(&rec.refCount) > 0, "too many releases") + debug.Assert(rec.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&rec.refCount, -1) == 0 { + if rec.refCount.Add(-1) == 0 { for _, arr := range rec.arrs { arr.Release() } @@ -273,7 +275,7 @@ func (rec *simpleRecord) MarshalJSON() ([]byte, error) { // RecordBuilder eases the process of building a Record, iteratively, from // a known Schema. type RecordBuilder struct { - refCount int64 + refCount atomic.Int64 mem memory.Allocator schema *arrow.Schema fields []Builder @@ -282,11 +284,11 @@ type RecordBuilder struct { // NewRecordBuilder returns a builder, using the provided memory allocator and a schema. func NewRecordBuilder(mem memory.Allocator, schema *arrow.Schema) *RecordBuilder { b := &RecordBuilder{ - refCount: 1, - mem: mem, - schema: schema, - fields: make([]Builder, schema.NumFields()), + mem: mem, + schema: schema, + fields: make([]Builder, schema.NumFields()), } + b.refCount.Add(1) for i := 0; i < schema.NumFields(); i++ { b.fields[i] = NewBuilder(b.mem, schema.Field(i).Type) @@ -298,14 +300,14 @@ func NewRecordBuilder(mem memory.Allocator, schema *arrow.Schema) *RecordBuilder // Retain increases the reference count by 1. // Retain may be called simultaneously from multiple goroutines. func (b *RecordBuilder) Retain() { - atomic.AddInt64(&b.refCount, 1) + b.refCount.Add(1) } // Release decreases the reference count by 1. 
func (b *RecordBuilder) Release() { - debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases") + debug.Assert(b.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&b.refCount, -1) == 0 { + if b.refCount.Add(-1) == 0 { for _, f := range b.fields { f.Release() } @@ -405,6 +407,84 @@ func (b *RecordBuilder) UnmarshalJSON(data []byte) error { return nil } +type iterReader struct { + refCount atomic.Int64 + + schema *arrow.Schema + cur arrow.Record + + next func() (arrow.Record, error, bool) + stop func() + + err error +} + +func (ir *iterReader) Schema() *arrow.Schema { return ir.schema } + +func (ir *iterReader) Retain() { ir.refCount.Add(1) } +func (ir *iterReader) Release() { + debug.Assert(ir.refCount.Load() > 0, "too many releases") + + if ir.refCount.Add(-1) == 0 { + ir.stop() + ir.schema, ir.next = nil, nil + if ir.cur != nil { + ir.cur.Release() + } + } +} + +func (ir *iterReader) Record() arrow.Record { return ir.cur } +func (ir *iterReader) Err() error { return ir.err } + +func (ir *iterReader) Next() bool { + if ir.cur != nil { + ir.cur.Release() + } + + var ok bool + ir.cur, ir.err, ok = ir.next() + if ir.err != nil { + ir.stop() + return false + } + + return ok +} + +// ReaderFromIter wraps a go iterator for arrow.Record + error into a RecordReader +// interface object for ease of use. +func ReaderFromIter(schema *arrow.Schema, itr iter.Seq2[arrow.Record, error]) RecordReader { + next, stop := iter.Pull2(itr) + rdr := &iterReader{ + schema: schema, + next: next, + stop: stop, + } + rdr.refCount.Add(1) + return rdr +} + +// IterFromReader converts a RecordReader interface into an iterator that +// you can use range on. The semantics are still important, if a record +// that is returned is desired to be utilized beyond the scope of an iteration +// then Retain must be called on it. +func IterFromReader(rdr RecordReader) iter.Seq2[arrow.Record, error] { + rdr.Retain() + return func(yield func(arrow.Record, error) bool) { + defer rdr.Release() + for rdr.Next() { + if !yield(rdr.Record(), nil) { + return + } + } + + if rdr.Err() != nil { + yield(nil, rdr.Err()) + } + } +} + var ( _ arrow.Record = (*simpleRecord)(nil) _ RecordReader = (*simpleRecords)(nil) diff --git a/arrow/array/record_test.go b/arrow/array/record_test.go index 91a31cb1..2a61bddf 100644 --- a/arrow/array/record_test.go +++ b/arrow/array/record_test.go @@ -301,33 +301,97 @@ func TestRecordReader(t *testing.T) { defer rec2.Release() recs := []arrow.Record{rec1, rec2} - itr, err := array.NewRecordReader(schema, recs) - if err != nil { - t.Fatal(err) - } - defer itr.Release() + t.Run("simple reader", func(t *testing.T) { + itr, err := array.NewRecordReader(schema, recs) + if err != nil { + t.Fatal(err) + } + defer itr.Release() - itr.Retain() - itr.Release() + itr.Retain() + itr.Release() - if got, want := itr.Schema(), schema; !got.Equal(want) { - t.Fatalf("invalid schema. got=%#v, want=%#v", got, want) - } + if got, want := itr.Schema(), schema; !got.Equal(want) { + t.Fatalf("invalid schema. got=%#v, want=%#v", got, want) + } - n := 0 - for itr.Next() { - n++ - if got, want := itr.Record(), recs[n-1]; !reflect.DeepEqual(got, want) { - t.Fatalf("itr[%d], invalid record. got=%#v, want=%#v", n-1, got, want) + n := 0 + for itr.Next() { + n++ + if got, want := itr.Record(), recs[n-1]; !reflect.DeepEqual(got, want) { + t.Fatalf("itr[%d], invalid record. 
got=%#v, want=%#v", n-1, got, want) + } + } + if err := itr.Err(); err != nil { + t.Fatalf("itr error: %#v", err) } - } - if err := itr.Err(); err != nil { - t.Fatalf("itr error: %#v", err) - } - if n != len(recs) { - t.Fatalf("invalid number of iterations. got=%d, want=%d", n, len(recs)) - } + if n != len(recs) { + t.Fatalf("invalid number of iterations. got=%d, want=%d", n, len(recs)) + } + }) + + t.Run("iter to reader", func(t *testing.T) { + itr := func(yield func(arrow.Record, error) bool) { + for _, r := range recs { + if !yield(r, nil) { + return + } + } + } + + rdr := array.ReaderFromIter(schema, itr) + defer rdr.Release() + + rdr.Retain() + rdr.Release() + + if got, want := rdr.Schema(), schema; !got.Equal(want) { + t.Fatalf("invalid schema. got=%#v, want=%#v", got, want) + } + + n := 0 + for rdr.Next() { + n++ + // facet of using the simple record reader with a slice + // by default it will release records when the reader is released + // leading to too many releases on the original record + // so we retain it to keep it from going away while the test runs + rdr.Record().Retain() + if got, want := rdr.Record(), recs[n-1]; !reflect.DeepEqual(got, want) { + t.Fatalf("itr[%d], invalid record. got=%#v, want=%#v", n-1, got, want) + } + } + if err := rdr.Err(); err != nil { + t.Fatalf("itr error: %#v", err) + } + + if n != len(recs) { + t.Fatalf("invalid number of iterations. got=%d, want=%d", n, len(recs)) + } + }) + + t.Run("reader to iter", func(t *testing.T) { + rdr, err := array.NewRecordReader(schema, recs) + if err != nil { + t.Fatal(err) + } + + itr := array.IterFromReader(rdr) + rdr.Release() + + n := 0 + for rec, err := range itr { + if err != nil { + t.Fatalf("itr error: %#v", err) + } + + n++ + if got, want := rec, recs[n-1]; !reflect.DeepEqual(got, want) { + t.Fatalf("itr[%d], invalid record. got=%#v, want=%#v", n-1, got, want) + } + } + }) for _, tc := range []struct { name string diff --git a/arrow/array/string.go b/arrow/array/string.go index 5197e77f..d42492d6 100644 --- a/arrow/array/string.go +++ b/arrow/array/string.go @@ -44,7 +44,7 @@ type String struct { // NewStringData constructs a new String array from data. func NewStringData(data arrow.ArrayData) *String { a := &String{} - a.refCount = 1 + a.refCount.Add(1) a.setData(data.(*Data)) return a } @@ -191,7 +191,7 @@ type LargeString struct { // NewStringData constructs a new String array from data. func NewLargeStringData(data arrow.ArrayData) *LargeString { a := &LargeString{} - a.refCount = 1 + a.refCount.Add(1) a.setData(data.(*Data)) return a } @@ -332,7 +332,7 @@ type StringView struct { func NewStringViewData(data arrow.ArrayData) *StringView { a := &StringView{} - a.refCount = 1 + a.refCount.Add(1) a.setData(data.(*Data)) return a } @@ -715,4 +715,8 @@ var ( _ StringLike = (*String)(nil) _ StringLike = (*LargeString)(nil) _ StringLike = (*StringView)(nil) + + _ arrow.TypedArray[string] = (*String)(nil) + _ arrow.TypedArray[string] = (*LargeString)(nil) + _ arrow.TypedArray[string] = (*StringView)(nil) ) diff --git a/arrow/array/struct.go b/arrow/array/struct.go index 7f65f8d2..957947b3 100644 --- a/arrow/array/struct.go +++ b/arrow/array/struct.go @@ -21,7 +21,6 @@ import ( "errors" "fmt" "strings" - "sync/atomic" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/bitutil" @@ -47,6 +46,13 @@ func NewStructArray(cols []arrow.Array, names []string) (*Struct, error) { // and provided fields. 
As opposed to NewStructArray, this allows you to provide // the full fields to utilize for the struct column instead of just the names. func NewStructArrayWithFields(cols []arrow.Array, fields []arrow.Field) (*Struct, error) { + return NewStructArrayWithFieldsAndNulls(cols, fields, nil, 0, 0) +} + +// NewStructArrayWithFieldsAndNulls is like NewStructArrayWithFields as a convenience function, +// but also takes in a null bitmap, the number of nulls, and an optional offset +// to use for creating the Struct Array. +func NewStructArrayWithFieldsAndNulls(cols []arrow.Array, fields []arrow.Field, nullBitmap *memory.Buffer, nullCount int, offset int) (*Struct, error) { if len(cols) != len(fields) { return nil, fmt.Errorf("%w: mismatching number of fields and child arrays", arrow.ErrInvalid) } @@ -64,15 +70,18 @@ func NewStructArrayWithFields(cols []arrow.Array, fields []arrow.Field) (*Struct return nil, fmt.Errorf("%w: mismatching data type for child #%d, field says '%s', got '%s'", arrow.ErrInvalid, i, fields[i].Type, c.DataType()) } - if !fields[i].Nullable && c.NullN() > 0 { - return nil, fmt.Errorf("%w: field says not-nullable, child #%d has nulls", - arrow.ErrInvalid, i) - } children[i] = c.Data() } - data := NewData(arrow.StructOf(fields...), length, []*memory.Buffer{nil}, children, 0, 0) + if nullBitmap == nil { + if nullCount > 0 { + return nil, fmt.Errorf("%w: null count is greater than 0 but null bitmap is nil", arrow.ErrInvalid) + } + nullCount = 0 + } + + data := NewData(arrow.StructOf(fields...), length-offset, []*memory.Buffer{nullBitmap}, children, nullCount, offset) defer data.Release() return NewStructData(data), nil } @@ -107,7 +116,7 @@ func NewStructArrayWithNulls(cols []arrow.Array, names []string, nullBitmap *mem // NewStructData returns a new Struct array value from data. func NewStructData(data arrow.ArrayData) *Struct { a := &Struct{} - a.refCount = 1 + a.refCount.Add(1) a.setData(data.(*Data)) return a } @@ -256,10 +265,12 @@ type StructBuilder struct { // NewStructBuilder returns a builder, using the provided memory allocator. func NewStructBuilder(mem memory.Allocator, dtype *arrow.StructType) *StructBuilder { b := &StructBuilder{ - builder: builder{refCount: 1, mem: mem}, + builder: builder{mem: mem}, dtype: dtype, fields: make([]Builder, dtype.NumFields()), } + b.refCount.Add(1) + for i, f := range dtype.Fields() { b.fields[i] = NewBuilder(b.mem, f.Type) } @@ -278,9 +289,9 @@ func (b *StructBuilder) Type() arrow.DataType { // Release decreases the reference count by 1. // When the reference count goes to zero, the memory is freed. 
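Referring back to the ReaderFromIter/IterFromReader helpers added to record.go earlier in this diff, a short usage sketch (schema and values are invented for illustration) of consuming a RecordReader with range might look like:

package main

import (
	"fmt"

	"github.com/apache/arrow-go/v18/arrow"
	"github.com/apache/arrow-go/v18/arrow/array"
	"github.com/apache/arrow-go/v18/arrow/memory"
)

func main() {
	mem := memory.NewGoAllocator()
	schema := arrow.NewSchema([]arrow.Field{{Name: "x", Type: arrow.PrimitiveTypes.Int64}}, nil)

	b := array.NewRecordBuilder(mem, schema)
	defer b.Release()
	b.Field(0).(*array.Int64Builder).AppendValues([]int64{1, 2, 3}, nil)

	rec := b.NewRecord()
	defer rec.Release()

	rdr, err := array.NewRecordReader(schema, []arrow.Record{rec})
	if err != nil {
		panic(err)
	}
	defer rdr.Release()

	// records yielded here are only valid for the iteration unless Retain is
	// called on them, per the IterFromReader doc comment above.
	for r, err := range array.IterFromReader(rdr) {
		if err != nil {
			panic(err)
		}
		fmt.Println(r.NumRows()) // prints 3
	}
}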
func (b *StructBuilder) Release() { - debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases") + debug.Assert(b.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&b.refCount, -1) == 0 { + if b.refCount.Add(-1) == 0 { if b.nullBitmap != nil { b.nullBitmap.Release() b.nullBitmap = nil diff --git a/arrow/array/struct_test.go b/arrow/array/struct_test.go index a06ba831..24f522ed 100644 --- a/arrow/array/struct_test.go +++ b/arrow/array/struct_test.go @@ -24,6 +24,7 @@ import ( "github.com/apache/arrow-go/v18/arrow/array" "github.com/apache/arrow-go/v18/arrow/memory" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestStructArray(t *testing.T) { @@ -530,3 +531,34 @@ func TestStructArrayUnmarshalJSONMissingFields(t *testing.T) { ) } } + +func TestCreateStructWithNulls(t *testing.T) { + pool := memory.NewCheckedAllocator(memory.NewGoAllocator()) + defer pool.AssertSize(t, 0) + + var ( + fields = []arrow.Field{ + {Name: "f1", Type: arrow.PrimitiveTypes.Float64, Nullable: true}, + {Name: "f2", Type: arrow.PrimitiveTypes.Int32, Nullable: false}, + } + dtype = arrow.StructOf(fields...) + ) + + sb := array.NewStructBuilder(pool, dtype) + defer sb.Release() + + sb.AppendNulls(100) + + arr := sb.NewArray().(*array.Struct) + defer arr.Release() + + assert.EqualValues(t, 100, arr.Len()) + assert.EqualValues(t, 100, arr.NullN()) + + arr2, err := array.NewStructArrayWithFieldsAndNulls( + []arrow.Array{arr.Field(0), arr.Field(1)}, fields, arr.Data().Buffers()[0], arr.NullN(), 0) + require.NoError(t, err) + defer arr2.Release() + + assert.True(t, array.Equal(arr, arr2)) +} diff --git a/arrow/array/table.go b/arrow/array/table.go index 95ac67f2..367b1b10 100644 --- a/arrow/array/table.go +++ b/arrow/array/table.go @@ -85,7 +85,7 @@ func NewChunkedSlice(a *arrow.Chunked, i, j int64) *arrow.Chunked { // simpleTable is a basic, non-lazy in-memory table. type simpleTable struct { - refCount int64 + refCount atomic.Int64 rows int64 cols []arrow.Column @@ -101,11 +101,11 @@ type simpleTable struct { // NewTable panics if rows is larger than the height of the columns. func NewTable(schema *arrow.Schema, cols []arrow.Column, rows int64) arrow.Table { tbl := simpleTable{ - refCount: 1, - rows: rows, - cols: cols, - schema: schema, + rows: rows, + cols: cols, + schema: schema, } + tbl.refCount.Add(1) if tbl.rows < 0 { switch len(tbl.cols) { @@ -150,11 +150,11 @@ func NewTableFromSlice(schema *arrow.Schema, data [][]arrow.Array) arrow.Table { } tbl := simpleTable{ - refCount: 1, - schema: schema, - cols: cols, - rows: int64(cols[0].Len()), + schema: schema, + cols: cols, + rows: int64(cols[0].Len()), } + tbl.refCount.Add(1) defer func() { if r := recover(); r != nil { @@ -241,16 +241,16 @@ func (tbl *simpleTable) validate() { // Retain increases the reference count by 1. // Retain may be called simultaneously from multiple goroutines. func (tbl *simpleTable) Retain() { - atomic.AddInt64(&tbl.refCount, 1) + tbl.refCount.Add(1) } // Release decreases the reference count by 1. // When the reference count goes to zero, the memory is freed. // Release may be called simultaneously from multiple goroutines. 
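Beyond the builder round trip in TestCreateStructWithNulls, the new NewStructArrayWithFieldsAndNulls can also be handed a hand-built validity bitmap (LSB-first, one bit per row, as in the Arrow format). The following sketch assumes a single Int32 child and is illustrative only:

package main

import (
	"fmt"

	"github.com/apache/arrow-go/v18/arrow"
	"github.com/apache/arrow-go/v18/arrow/array"
	"github.com/apache/arrow-go/v18/arrow/memory"
)

func main() {
	mem := memory.NewGoAllocator()

	ib := array.NewInt32Builder(mem)
	defer ib.Release()
	ib.AppendValues([]int32{1, 2, 3}, nil)
	child := ib.NewInt32Array()
	defer child.Release()

	fields := []arrow.Field{{Name: "f1", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}

	// rows 0 and 2 valid, row 1 null: bits 0b101 -> 0x05, nullCount 1, offset 0
	validity := memory.NewBufferBytes([]byte{0x05})

	st, err := array.NewStructArrayWithFieldsAndNulls(
		[]arrow.Array{child}, fields, validity, 1, 0)
	if err != nil {
		panic(err)
	}
	defer st.Release()

	fmt.Println(st.Len(), st.NullN(), st.IsNull(1)) // 3 1 true
}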
func (tbl *simpleTable) Release() { - debug.Assert(atomic.LoadInt64(&tbl.refCount) > 0, "too many releases") + debug.Assert(tbl.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&tbl.refCount, -1) == 0 { + if tbl.refCount.Add(-1) == 0 { for i := range tbl.cols { tbl.cols[i].Release() } @@ -279,7 +279,7 @@ func (tbl *simpleTable) String() string { // TableReader is a Record iterator over a (possibly chunked) Table type TableReader struct { - refCount int64 + refCount atomic.Int64 tbl arrow.Table cur int64 // current row @@ -297,15 +297,15 @@ type TableReader struct { func NewTableReader(tbl arrow.Table, chunkSize int64) *TableReader { ncols := tbl.NumCols() tr := &TableReader{ - refCount: 1, - tbl: tbl, - cur: 0, - max: int64(tbl.NumRows()), - chksz: chunkSize, - chunks: make([]*arrow.Chunked, ncols), - slots: make([]int, ncols), - offsets: make([]int64, ncols), + tbl: tbl, + cur: 0, + max: int64(tbl.NumRows()), + chksz: chunkSize, + chunks: make([]*arrow.Chunked, ncols), + slots: make([]int, ncols), + offsets: make([]int64, ncols), } + tr.refCount.Add(1) tr.tbl.Retain() if tr.chksz <= 0 { @@ -383,16 +383,16 @@ func (tr *TableReader) Next() bool { // Retain increases the reference count by 1. // Retain may be called simultaneously from multiple goroutines. func (tr *TableReader) Retain() { - atomic.AddInt64(&tr.refCount, 1) + tr.refCount.Add(1) } // Release decreases the reference count by 1. // When the reference count goes to zero, the memory is freed. // Release may be called simultaneously from multiple goroutines. func (tr *TableReader) Release() { - debug.Assert(atomic.LoadInt64(&tr.refCount) > 0, "too many releases") + debug.Assert(tr.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&tr.refCount, -1) == 0 { + if tr.refCount.Add(-1) == 0 { tr.tbl.Release() for _, chk := range tr.chunks { chk.Release() diff --git a/arrow/array/timestamp.go b/arrow/array/timestamp.go index 37359db1..9f8ca478 100644 --- a/arrow/array/timestamp.go +++ b/arrow/array/timestamp.go @@ -21,7 +21,6 @@ import ( "fmt" "reflect" "strings" - "sync/atomic" "time" "github.com/apache/arrow-go/v18/arrow" @@ -40,7 +39,7 @@ type Timestamp struct { // NewTimestampData creates a new Timestamp from Data. func NewTimestampData(data arrow.ArrayData) *Timestamp { a := &Timestamp{} - a.refCount = 1 + a.refCount.Add(1) a.setData(data.(*Data)) return a } @@ -53,8 +52,10 @@ func (a *Timestamp) Reset(data *Data) { // Value returns the value at the specified index. func (a *Timestamp) Value(i int) arrow.Timestamp { return a.values[i] } +func (a *Timestamp) Values() []arrow.Timestamp { return a.values } + // TimestampValues returns the values. -func (a *Timestamp) TimestampValues() []arrow.Timestamp { return a.values } +func (a *Timestamp) TimestampValues() []arrow.Timestamp { return a.Values() } // String returns a string representation of the array. func (a *Timestamp) String() string { @@ -132,7 +133,9 @@ type TimestampBuilder struct { } func NewTimestampBuilder(mem memory.Allocator, dtype *arrow.TimestampType) *TimestampBuilder { - return &TimestampBuilder{builder: builder{refCount: 1, mem: mem}, dtype: dtype} + tb := &TimestampBuilder{builder: builder{mem: mem}, dtype: dtype} + tb.refCount.Add(1) + return tb } func (b *TimestampBuilder) Type() arrow.DataType { return b.dtype } @@ -140,9 +143,9 @@ func (b *TimestampBuilder) Type() arrow.DataType { return b.dtype } // Release decreases the reference count by 1. // When the reference count goes to zero, the memory is freed. 
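The new Values accessor on Timestamp above (TimestampValues now simply forwards to it) exposes the underlying slice directly; a small hedged sketch, assuming arr is a *array.Timestamp obtained elsewhere:

// Hedged sketch: range over the raw arrow.Timestamp slice instead of calling
// Value in a loop, consulting the null bitmap for validity as usual.
var total int64
for i, ts := range arr.Values() {
	if arr.IsValid(i) {
		total += int64(ts)
	}
}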
func (b *TimestampBuilder) Release() { - debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases") + debug.Assert(b.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&b.refCount, -1) == 0 { + if b.refCount.Add(-1) == 0 { if b.nullBitmap != nil { b.nullBitmap.Release() b.nullBitmap = nil @@ -375,6 +378,7 @@ func (b *TimestampBuilder) UnmarshalJSON(data []byte) error { } var ( - _ arrow.Array = (*Timestamp)(nil) - _ Builder = (*TimestampBuilder)(nil) + _ arrow.Array = (*Timestamp)(nil) + _ Builder = (*TimestampBuilder)(nil) + _ arrow.TypedArray[arrow.Timestamp] = (*Timestamp)(nil) ) diff --git a/arrow/array/union.go b/arrow/array/union.go index 6f3a9a6e..9c13af05 100644 --- a/arrow/array/union.go +++ b/arrow/array/union.go @@ -23,7 +23,6 @@ import ( "math" "reflect" "strings" - "sync/atomic" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/bitutil" @@ -246,7 +245,7 @@ func NewSparseUnion(dt *arrow.SparseUnionType, length int, children []arrow.Arra // NewSparseUnionData constructs a SparseUnion array from the given ArrayData object. func NewSparseUnionData(data arrow.ArrayData) *SparseUnion { a := &SparseUnion{} - a.refCount = 1 + a.refCount.Add(1) a.setData(data.(*Data)) return a } @@ -506,7 +505,7 @@ func NewDenseUnion(dt *arrow.DenseUnionType, length int, children []arrow.Array, // NewDenseUnionData constructs a DenseUnion array from the given ArrayData object. func NewDenseUnionData(data arrow.ArrayData) *DenseUnion { a := &DenseUnion{} - a.refCount = 1 + a.refCount.Add(1) a.setData(data.(*Data)) return a } @@ -736,12 +735,12 @@ type unionBuilder struct { typesBuilder *int8BufferBuilder } -func newUnionBuilder(mem memory.Allocator, children []Builder, typ arrow.UnionType) unionBuilder { +func newUnionBuilder(mem memory.Allocator, children []Builder, typ arrow.UnionType) *unionBuilder { if children == nil { children = make([]Builder, 0) } b := unionBuilder{ - builder: builder{refCount: 1, mem: mem}, + builder: builder{mem: mem}, mode: typ.Mode(), codes: typ.TypeCodes(), children: children, @@ -750,6 +749,7 @@ func newUnionBuilder(mem memory.Allocator, children []Builder, typ arrow.UnionTy childFields: make([]arrow.Field, len(children)), typesBuilder: newInt8BufferBuilder(mem), } + b.refCount.Add(1) b.typeIDtoChildID[0] = arrow.InvalidUnionChildID for i := 1; i < len(b.typeIDtoChildID); i *= 2 { @@ -767,7 +767,7 @@ func newUnionBuilder(mem memory.Allocator, children []Builder, typ arrow.UnionTy b.typeIDtoBuilder[typeID] = c } - return b + return &b } func (b *unionBuilder) NumChildren() int { @@ -795,9 +795,9 @@ func (b *unionBuilder) reserve(elements int, resize func(int)) { } func (b *unionBuilder) Release() { - debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases") + debug.Assert(b.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&b.refCount, -1) == 0 { + if b.refCount.Add(-1) == 0 { for _, c := range b.children { c.Release() } @@ -854,7 +854,6 @@ func (b *unionBuilder) nextTypeID() arrow.UnionTypeCode { id := b.denseTypeID b.denseTypeID++ return id - } func (b *unionBuilder) newData() *Data { @@ -879,7 +878,7 @@ func (b *unionBuilder) newData() *Data { // that they have the correct number of preceding elements that have been // added to the builder beforehand. type SparseUnionBuilder struct { - unionBuilder + *unionBuilder } // NewEmptySparseUnionBuilder is a helper to construct a SparseUnionBuilder @@ -1109,7 +1108,7 @@ func (b *SparseUnionBuilder) UnmarshalOne(dec *json.Decoder) error { // methods. 
You can also add new types to the union on the fly by using // AppendChild. type DenseUnionBuilder struct { - unionBuilder + *unionBuilder offsetsBuilder *int32BufferBuilder } @@ -1228,9 +1227,9 @@ func (b *DenseUnionBuilder) Append(nextType arrow.UnionTypeCode) { } func (b *DenseUnionBuilder) Release() { - debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases") + debug.Assert(b.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&b.refCount, -1) == 0 { + if b.refCount.Add(-1) == 0 { for _, c := range b.children { c.Release() } diff --git a/arrow/avro/reader.go b/arrow/avro/reader.go index c19c9eda..023eabdd 100644 --- a/arrow/avro/reader.go +++ b/arrow/avro/reader.go @@ -54,7 +54,7 @@ type OCFReader struct { avroSchemaEdits []schemaEdit schema *arrow.Schema - refs int64 + refs atomic.Int64 bld *array.RecordBuilder bldMap *fieldPos ldr *dataLoader @@ -89,11 +89,12 @@ func NewOCFReader(r io.Reader, opts ...Option) (*OCFReader, error) { rr := &OCFReader{ r: ocfr, - refs: 1, chunk: 1, avroChanSize: 500, recChanSize: 10, } + rr.refs.Add(1) + for _, opt := range opts { opt(rr) } @@ -318,16 +319,16 @@ func WithChunk(n int) Option { // Retain increases the reference count by 1. // Retain may be called simultaneously from multiple goroutines. func (r *OCFReader) Retain() { - atomic.AddInt64(&r.refs, 1) + r.refs.Add(1) } // Release decreases the reference count by 1. // When the reference count goes to zero, the memory is freed. // Release may be called simultaneously from multiple goroutines. func (r *OCFReader) Release() { - debug.Assert(atomic.LoadInt64(&r.refs) > 0, "too many releases") + debug.Assert(r.refs.Load() > 0, "too many releases") - if atomic.AddInt64(&r.refs, -1) == 0 { + if r.refs.Add(-1) == 0 { if r.cur != nil { r.cur.Release() } diff --git a/arrow/cdata/cdata.go b/arrow/cdata/cdata.go index 86f2b50e..b56edb4f 100644 --- a/arrow/cdata/cdata.go +++ b/arrow/cdata/cdata.go @@ -194,7 +194,8 @@ func importSchema(schema *CArrowSchema) (ret arrow.Field, err error) { ret.Type = &arrow.DictionaryType{ IndexType: ret.Type, ValueType: valueField.Type, - Ordered: schema.dictionary.flags&C.ARROW_FLAG_DICTIONARY_ORDERED != 0} + Ordered: schema.dictionary.flags&C.ARROW_FLAG_DICTIONARY_ORDERED != 0, + } } return @@ -460,7 +461,7 @@ func (imp *cimporter) doImportArr(src *CArrowArray) error { // struct immediately after import, since we have no imported // memory that we have to track the lifetime of. 
defer func() { - if imp.alloc.bufCount == 0 { + if imp.alloc.bufCount.Load() == 0 { C.ArrowArrayRelease(imp.arr) C.free(unsafe.Pointer(imp.arr)) } @@ -662,9 +663,7 @@ func (imp *cimporter) importStringLike(offsetByteWidth int64) (err error) { return } - var ( - nulls, offsets, values *memory.Buffer - ) + var nulls, offsets, values *memory.Buffer if nulls, err = imp.importNullBitmap(0); err != nil { return } diff --git a/arrow/cdata/import_allocator.go b/arrow/cdata/import_allocator.go index ba5edf40..d2cc44b7 100644 --- a/arrow/cdata/import_allocator.go +++ b/arrow/cdata/import_allocator.go @@ -28,13 +28,13 @@ import ( import "C" type importAllocator struct { - bufCount int64 + bufCount atomic.Int64 arr *CArrowArray } func (i *importAllocator) addBuffer() { - atomic.AddInt64(&i.bufCount, 1) + i.bufCount.Add(1) } func (*importAllocator) Allocate(int) []byte { @@ -46,9 +46,9 @@ func (*importAllocator) Reallocate(int, []byte) []byte { } func (i *importAllocator) Free([]byte) { - debug.Assert(atomic.LoadInt64(&i.bufCount) > 0, "too many releases") + debug.Assert(i.bufCount.Load() > 0, "too many releases") - if atomic.AddInt64(&i.bufCount, -1) == 0 { + if i.bufCount.Add(-1) == 0 { defer C.free(unsafe.Pointer(i.arr)) C.ArrowArrayRelease(i.arr) if C.ArrowArrayIsReleased(i.arr) != 1 { diff --git a/arrow/compute/arithmetic_test.go b/arrow/compute/arithmetic_test.go index 6db02129..07fb1fc9 100644 --- a/arrow/compute/arithmetic_test.go +++ b/arrow/compute/arithmetic_test.go @@ -204,7 +204,7 @@ type BinaryArithmeticSuite[T arrow.NumericType] struct { scalarEqualOpts []scalar.EqualOption } -func (BinaryArithmeticSuite[T]) DataType() arrow.DataType { +func (*BinaryArithmeticSuite[T]) DataType() arrow.DataType { return arrow.GetDataType[T]() } diff --git a/arrow/compute/exec/kernel.go b/arrow/compute/exec/kernel.go index d7de176c..fd3a52d9 100644 --- a/arrow/compute/exec/kernel.go +++ b/arrow/compute/exec/kernel.go @@ -68,6 +68,7 @@ type NonAggKernel interface { GetNullHandling() NullHandling GetMemAlloc() MemAlloc CanFillSlices() bool + Cleanup() error } // KernelCtx is a small struct holding the context for a kernel execution @@ -604,6 +605,7 @@ type ScalarKernel struct { CanWriteIntoSlices bool NullHandling NullHandling MemAlloc MemAlloc + CleanupFn func(KernelState) error } // NewScalarKernel constructs a new kernel for scalar execution, constructing @@ -629,6 +631,13 @@ func NewScalarKernelWithSig(sig *KernelSignature, exec ArrayKernelExec, init Ker } } +func (s *ScalarKernel) Cleanup() error { + if s.CleanupFn != nil { + return s.CleanupFn(s.Data) + } + return nil +} + func (s *ScalarKernel) Exec(ctx *KernelCtx, sp *ExecSpan, out *ExecResult) error { return s.ExecFn(ctx, sp, out) } @@ -693,3 +702,4 @@ func (s *VectorKernel) Exec(ctx *KernelCtx, sp *ExecSpan, out *ExecResult) error func (s VectorKernel) GetNullHandling() NullHandling { return s.NullHandling } func (s VectorKernel) GetMemAlloc() MemAlloc { return s.MemAlloc } func (s VectorKernel) CanFillSlices() bool { return s.CanWriteIntoSlices } +func (s VectorKernel) Cleanup() error { return nil } diff --git a/arrow/compute/exec/utils.go b/arrow/compute/exec/utils.go index e3685205..58c4c0ce 100644 --- a/arrow/compute/exec/utils.go +++ b/arrow/compute/exec/utils.go @@ -158,7 +158,7 @@ func RechunkArraysConsistently(groups [][]arrow.Array) [][]arrow.Array { type ChunkResolver struct { offsets []int64 - cached int64 + cached atomic.Int64 } func NewChunkResolver(chunks []arrow.Array) *ChunkResolver { @@ -184,7 +184,7 @@ func (c *ChunkResolver) 
Resolve(idx int64) (chunk, index int64) { return 0, idx } - cached := atomic.LoadInt64(&c.cached) + cached := c.cached.Load() cacheHit := idx >= c.offsets[cached] && idx < c.offsets[cached+1] if cacheHit { return cached, idx - c.offsets[cached] @@ -196,7 +196,7 @@ func (c *ChunkResolver) Resolve(idx int64) (chunk, index int64) { } chunk, index = int64(chkIdx), idx-c.offsets[chkIdx] - atomic.StoreInt64(&c.cached, chunk) + c.cached.Store(chunk) return } @@ -214,7 +214,8 @@ type BoolIter struct { func NewBoolIter(arr *ArraySpan) ArrayIter[bool] { return &BoolIter{ - Rdr: bitutil.NewBitmapReader(arr.Buffers[1].Buf, int(arr.Offset), int(arr.Len))} + Rdr: bitutil.NewBitmapReader(arr.Buffers[1].Buf, int(arr.Offset), int(arr.Len)), + } } func (b *BoolIter) Next() (out bool) { diff --git a/arrow/compute/executor.go b/arrow/compute/executor.go index 54c65adc..bf41036e 100644 --- a/arrow/compute/executor.go +++ b/arrow/compute/executor.go @@ -20,6 +20,7 @@ package compute import ( "context" + "errors" "fmt" "math" "runtime" @@ -579,6 +580,10 @@ func (s *scalarExecutor) WrapResults(ctx context.Context, out <-chan Datum, hasC } func (s *scalarExecutor) executeSpans(data chan<- Datum) (err error) { + defer func() { + err = errors.Join(err, s.kernel.Cleanup()) + }() + var ( input exec.ExecSpan output exec.ExecResult @@ -645,7 +650,7 @@ func (s *scalarExecutor) executeSingleSpan(input *exec.ExecSpan, out *exec.ExecR return s.kernel.Exec(s.ctx, input, out) } -func (s *scalarExecutor) setupPrealloc(totalLen int64, args []Datum) error { +func (s *scalarExecutor) setupPrealloc(_ int64, args []Datum) error { s.numOutBuf = len(s.outType.Layout().Buffers) outTypeID := s.outType.ID() // default to no validity pre-allocation for the following cases: diff --git a/arrow/compute/expression.go b/arrow/compute/expression.go index 88e1dde3..4e60d38c 100644 --- a/arrow/compute/expression.go +++ b/arrow/compute/expression.go @@ -490,6 +490,7 @@ func Cast(ex Expression, dt arrow.DataType) Expression { return NewCall("cast", []Expression{ex}, opts) } +// Deprecated: Use SetOptions instead type SetLookupOptions struct { ValueSet Datum `compute:"value_set"` SkipNulls bool `compute:"skip_nulls"` diff --git a/arrow/compute/exprs/exec.go b/arrow/compute/exprs/exec.go index 2e643815..0d0a139f 100644 --- a/arrow/compute/exprs/exec.go +++ b/arrow/compute/exprs/exec.go @@ -524,7 +524,6 @@ func executeScalarBatch(ctx context.Context, input compute.ExecBatch, exp expr.E err error allScalar = true args = make([]compute.Datum, e.NArgs()) - argTypes = make([]arrow.DataType, e.NArgs()) ) for i := 0; i < e.NArgs(); i++ { switch v := e.Arg(i).(type) { @@ -543,20 +542,23 @@ func executeScalarBatch(ctx context.Context, input compute.ExecBatch, exp expr.E default: return nil, arrow.ErrNotImplemented } - - argTypes[i] = args[i].(compute.ArrayLikeDatum).Type() } _, conv, ok := ext.DecodeFunction(e.FuncRef()) if !ok { - return nil, arrow.ErrNotImplemented + return nil, fmt.Errorf("%w: %s", arrow.ErrNotImplemented, e.Name()) } - fname, opts, err := conv(e) + fname, args, opts, err := conv(e, args) if err != nil { return nil, err } + argTypes := make([]arrow.DataType, len(args)) + for i, arg := range args { + argTypes[i] = arg.(compute.ArrayLikeDatum).Type() + } + ectx := compute.GetExecCtx(ctx) fn, ok := ectx.Registry.GetFunction(fname) if !ok { diff --git a/arrow/compute/exprs/extension_types.go b/arrow/compute/exprs/extension_types.go index db780cbe..f44da19e 100644 --- a/arrow/compute/exprs/extension_types.go +++ 
b/arrow/compute/exprs/extension_types.go @@ -26,6 +26,7 @@ import ( "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/array" + "github.com/apache/arrow-go/v18/arrow/extensions" ) type simpleExtensionTypeFactory[P comparable] struct { @@ -95,13 +96,6 @@ type simpleExtensionArrayFactory[P comparable] struct { array.ExtensionArrayBase } -type uuidExtParams struct{} - -var uuidType = simpleExtensionTypeFactory[uuidExtParams]{ - name: "arrow.uuid", getStorage: func(uuidExtParams) arrow.DataType { - return &arrow.FixedSizeBinaryType{ByteWidth: 16} - }} - type fixedCharExtensionParams struct { Length int32 `json:"length"` } @@ -138,7 +132,7 @@ var intervalDayType = simpleExtensionTypeFactory[intervalDayExtensionParams]{ }, } -func uuid() arrow.DataType { return uuidType.CreateType(uuidExtParams{}) } +func uuid() arrow.DataType { return extensions.NewUUIDType() } func fixedChar(length int32) arrow.DataType { return fixedCharType.CreateType(fixedCharExtensionParams{Length: length}) } diff --git a/arrow/compute/exprs/types.go b/arrow/compute/exprs/types.go index f48a6c56..0c468f35 100644 --- a/arrow/compute/exprs/types.go +++ b/arrow/compute/exprs/types.go @@ -26,6 +26,7 @@ import ( "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/compute" + "github.com/apache/arrow-go/v18/arrow/scalar" "github.com/substrait-io/substrait-go/v3/expr" "github.com/substrait-io/substrait-go/v3/extensions" "github.com/substrait-io/substrait-go/v3/types" @@ -41,7 +42,8 @@ const ( SubstraitComparisonFuncsURI = SubstraitDefaultURIPrefix + "functions_comparison.yaml" SubstraitBooleanFuncsURI = SubstraitDefaultURIPrefix + "functions_boolean.yaml" - TimestampTzTimezone = "UTC" + SubstraitIcebergSetFuncURI = "https://github.com/apache/iceberg-go/blob/main/table/substrait/functions_set.yaml" + TimestampTzTimezone = "UTC" ) var hashSeed maphash.Seed @@ -127,6 +129,15 @@ func init() { panic(err) } } + + for _, fn := range []string{"is_in"} { + err := DefaultExtensionIDRegistry.AddSubstraitScalarToArrow( + extensions.ID{URI: SubstraitIcebergSetFuncURI, Name: fn}, + setLookupFuncSubstraitToArrowFunc) + if err != nil { + panic(err) + } + } } type overflowBehavior string @@ -178,7 +189,7 @@ func parseOption[typ ~string](sf *expr.ScalarFunction, optionName string, parser return def, arrow.ErrNotImplemented } -type substraitToArrow = func(*expr.ScalarFunction) (fname string, opts compute.FunctionOptions, err error) +type substraitToArrow = func(*expr.ScalarFunction, []compute.Datum) (fname string, args []compute.Datum, opts compute.FunctionOptions, err error) type arrowToSubstrait = func(fname string) (extensions.ID, []*types.FunctionOption, error) var substraitToArrowFuncMap = map[string]string{ @@ -199,7 +210,32 @@ var arrowToSubstraitFuncMap = map[string]string{ "or_kleene": "or", } -func simpleMapSubstraitToArrowFunc(sf *expr.ScalarFunction) (fname string, opts compute.FunctionOptions, err error) { +func setLookupFuncSubstraitToArrowFunc(sf *expr.ScalarFunction, input []compute.Datum) (fname string, args []compute.Datum, opts compute.FunctionOptions, err error) { + fname, _, _ = strings.Cut(sf.Name(), ":") + f, ok := substraitToArrowFuncMap[fname] + if ok { + fname = f + } + + setopts := &compute.SetOptions{ + NullBehavior: compute.NullMatchingMatch, + } + switch input[1].Kind() { + case compute.KindArray, compute.KindChunked: + setopts.ValueSet = input[1] + case compute.KindScalar: + // should be a list scalar + setopts.ValueSet = compute.NewDatumWithoutOwning( + 
input[1].(*compute.ScalarDatum).Value.(*scalar.List).Value) + } + + args, opts = input[0:1], setopts + return +} + +func simpleMapSubstraitToArrowFunc(sf *expr.ScalarFunction, input []compute.Datum) (fname string, args []compute.Datum, opts compute.FunctionOptions, err error) { + args = input + fname, _, _ = strings.Cut(sf.Name(), ":") f, ok := substraitToArrowFuncMap[fname] if ok { @@ -219,19 +255,19 @@ func simpleMapArrowToSubstraitFunc(uri string) arrowToSubstrait { } func decodeOptionlessOverflowableArithmetic(n string) substraitToArrow { - return func(sf *expr.ScalarFunction) (fname string, opts compute.FunctionOptions, err error) { + return func(sf *expr.ScalarFunction, input []compute.Datum) (fname string, args []compute.Datum, opts compute.FunctionOptions, err error) { overflow, err := parseOption(sf, "overflow", &overflowParser, []overflowBehavior{overflowSILENT, overflowERROR}, overflowSILENT) if err != nil { - return n, nil, err + return n, input, nil, err } switch overflow { case overflowSILENT: - return n + "_unchecked", nil, nil + return n + "_unchecked", input, nil, nil case overflowERROR: - return n, nil, nil + return n, input, nil, nil default: - return n, nil, arrow.ErrNotImplemented + return n, input, nil, arrow.ErrNotImplemented } } } diff --git a/arrow/compute/internal/kernels/scalar_set_lookup.go b/arrow/compute/internal/kernels/scalar_set_lookup.go new file mode 100644 index 00000000..c3562675 --- /dev/null +++ b/arrow/compute/internal/kernels/scalar_set_lookup.go @@ -0,0 +1,297 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
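One note on the widened converter signature introduced above in exprs/types.go: substraitToArrow implementations now receive the argument Datums and return them (possibly rewritten) alongside the function name and options. A hypothetical pass-through converter following the new shape (the function name is made up for illustration):

// Hedged sketch of a converter matching the new substraitToArrow signature:
// map the substrait name to an arrow function name and forward the arguments
// unchanged, with no extra options.
func passthroughSubstraitToArrow(sf *expr.ScalarFunction, input []compute.Datum) (fname string, args []compute.Datum, opts compute.FunctionOptions, err error) {
	fname, _, _ = strings.Cut(sf.Name(), ":")
	return fname, input, nil, nil
}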
+ +package kernels + +import ( + "fmt" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/bitutil" + "github.com/apache/arrow-go/v18/arrow/compute/exec" + "github.com/apache/arrow-go/v18/arrow/internal/debug" + "github.com/apache/arrow-go/v18/arrow/memory" + "github.com/apache/arrow-go/v18/internal/bitutils" + "github.com/apache/arrow-go/v18/internal/hashing" +) + +type NullMatchingBehavior int8 + +const ( + NullMatchingMatch NullMatchingBehavior = iota + NullMatchingSkip + NullMatchingEmitNull + NullMatchingInconclusive +) + +func visitBinary[OffsetT int32 | int64](data *exec.ArraySpan, valid func([]byte) error, null func() error) error { + if data.Len == 0 { + return nil + } + + rawBytes := data.Buffers[2].Buf + offsets := exec.GetSpanOffsets[OffsetT](data, 1) + return bitutils.VisitBitBlocksShort(data.Buffers[0].Buf, data.Offset, data.Len, + func(pos int64) error { + return valid(rawBytes[offsets[pos]:offsets[pos+1]]) + }, null) +} + +func visitNumeric[T arrow.FixedWidthType](data *exec.ArraySpan, valid func(T) error, null func() error) error { + if data.Len == 0 { + return nil + } + + values := exec.GetSpanValues[T](data, 1) + return bitutils.VisitBitBlocksShort(data.Buffers[0].Buf, data.Offset, data.Len, + func(pos int64) error { + return valid(values[pos]) + }, null) +} + +func visitFSB(data *exec.ArraySpan, valid func([]byte) error, null func() error) error { + if data.Len == 0 { + return nil + } + + sz := int64(data.Type.(arrow.FixedWidthDataType).Bytes()) + rawBytes := data.Buffers[1].Buf + + return bitutils.VisitBitBlocksShort(data.Buffers[0].Buf, data.Offset, data.Len, + func(pos int64) error { + return valid(rawBytes[pos*sz : (pos+1)*sz]) + }, null) +} + +type SetLookupOptions struct { + ValueSetType arrow.DataType + TotalLen int64 + ValueSet []exec.ArraySpan + NullBehavior NullMatchingBehavior +} + +type lookupState interface { + Init(SetLookupOptions) error +} + +func CreateSetLookupState(opts SetLookupOptions, alloc memory.Allocator) (exec.KernelState, error) { + valueSetType := opts.ValueSetType + if valueSetType.ID() == arrow.EXTENSION { + valueSetType = valueSetType.(arrow.ExtensionType).StorageType() + } + + var state lookupState + switch ty := valueSetType.(type) { + case arrow.BinaryDataType: + switch ty.Layout().Buffers[1].ByteWidth { + case 4: + state = &SetLookupState[[]byte]{ + Alloc: alloc, + visitFn: visitBinary[int32], + } + case 8: + state = &SetLookupState[[]byte]{ + Alloc: alloc, + visitFn: visitBinary[int64], + } + } + case arrow.FixedWidthDataType: + switch ty.Bytes() { + case 1: + state = &SetLookupState[uint8]{ + Alloc: alloc, + visitFn: visitNumeric[uint8], + } + case 2: + state = &SetLookupState[uint16]{ + Alloc: alloc, + visitFn: visitNumeric[uint16], + } + case 4: + state = &SetLookupState[uint32]{ + Alloc: alloc, + visitFn: visitNumeric[uint32], + } + case 8: + state = &SetLookupState[uint64]{ + Alloc: alloc, + visitFn: visitNumeric[uint64], + } + default: + state = &SetLookupState[[]byte]{ + Alloc: alloc, + visitFn: visitFSB, + } + } + + default: + return nil, fmt.Errorf("%w: unsupported type %s for SetLookup functions", arrow.ErrInvalid, opts.ValueSetType) + } + + return state, state.Init(opts) +} + +type SetLookupState[T hashing.MemoTypes] struct { + visitFn func(*exec.ArraySpan, func(T) error, func() error) error + ValueSetType arrow.DataType + Alloc memory.Allocator + Lookup hashing.TypedMemoTable[T] + // When there are duplicates in value set, memotable indices + // must be mapped back to indices in the value set 
+ MemoIndexToValueIndex []int32 + NullIndex int32 + NullBehavior NullMatchingBehavior +} + +func (s *SetLookupState[T]) ValueType() arrow.DataType { + return s.ValueSetType +} + +func (s *SetLookupState[T]) Init(opts SetLookupOptions) error { + s.ValueSetType = opts.ValueSetType + s.NullBehavior = opts.NullBehavior + s.MemoIndexToValueIndex = make([]int32, 0, opts.TotalLen) + s.NullIndex = -1 + memoType := s.ValueSetType.ID() + if memoType == arrow.EXTENSION { + memoType = s.ValueSetType.(arrow.ExtensionType).StorageType().ID() + } + lookup, err := newMemoTable(s.Alloc, memoType) + if err != nil { + return err + } + s.Lookup = lookup.(hashing.TypedMemoTable[T]) + if s.Lookup == nil { + return fmt.Errorf("unsupported type %s for SetLookup functions", s.ValueSetType) + } + + var offset int64 + for _, c := range opts.ValueSet { + if err := s.AddArrayValueSet(&c, offset); err != nil { + return err + } + offset += c.Len + } + + lookupNull, _ := s.Lookup.GetNull() + if s.NullBehavior != NullMatchingSkip && lookupNull >= 0 { + s.NullIndex = int32(lookupNull) + } + return nil +} + +func (s *SetLookupState[T]) AddArrayValueSet(data *exec.ArraySpan, startIdx int64) error { + idx := startIdx + return s.visitFn(data, + func(v T) error { + memoSize := len(s.MemoIndexToValueIndex) + memoIdx, found, err := s.Lookup.InsertOrGet(v) + if err != nil { + return err + } + + if !found { + debug.Assert(memoIdx == memoSize, "inconsistent memo index and size") + s.MemoIndexToValueIndex = append(s.MemoIndexToValueIndex, int32(idx)) + } else { + debug.Assert(memoIdx < memoSize, "inconsistent memo index and size") + } + + idx++ + return nil + }, func() error { + memoSize := len(s.MemoIndexToValueIndex) + nullIdx, found := s.Lookup.GetOrInsertNull() + if !found { + debug.Assert(nullIdx == memoSize, "inconsistent memo index and size") + s.MemoIndexToValueIndex = append(s.MemoIndexToValueIndex, int32(idx)) + } else { + debug.Assert(nullIdx < memoSize, "inconsistent memo index and size") + } + + idx++ + return nil + }) +} + +func DispatchIsIn(state lookupState, in *exec.ArraySpan, out *exec.ExecResult) error { + inType := in.Type + if inType.ID() == arrow.EXTENSION { + inType = inType.(arrow.ExtensionType).StorageType() + } + + switch ty := inType.(type) { + case arrow.BinaryDataType: + return isInKernelExec(state.(*SetLookupState[[]byte]), in, out) + case arrow.FixedWidthDataType: + switch ty.Bytes() { + case 1: + return isInKernelExec(state.(*SetLookupState[uint8]), in, out) + case 2: + return isInKernelExec(state.(*SetLookupState[uint16]), in, out) + case 4: + return isInKernelExec(state.(*SetLookupState[uint32]), in, out) + case 8: + return isInKernelExec(state.(*SetLookupState[uint64]), in, out) + default: + return isInKernelExec(state.(*SetLookupState[[]byte]), in, out) + } + default: + return fmt.Errorf("%w: unsupported type %s for is_in function", arrow.ErrInvalid, in.Type) + } +} + +func isInKernelExec[T hashing.MemoTypes](state *SetLookupState[T], in *exec.ArraySpan, out *exec.ExecResult) error { + writerBool := bitutil.NewBitmapWriter(out.Buffers[1].Buf, int(out.Offset), int(out.Len)) + defer writerBool.Finish() + writerNulls := bitutil.NewBitmapWriter(out.Buffers[0].Buf, int(out.Offset), int(out.Len)) + defer writerNulls.Finish() + valueSetHasNull := state.NullIndex != -1 + return state.visitFn(in, + func(v T) error { + switch { + case state.Lookup.Exists(v): + writerBool.Set() + writerNulls.Set() + case state.NullBehavior == NullMatchingInconclusive && valueSetHasNull: + writerBool.Clear() + 
writerNulls.Clear() + default: + writerBool.Clear() + writerNulls.Set() + } + + writerBool.Next() + writerNulls.Next() + return nil + }, func() error { + switch { + case state.NullBehavior == NullMatchingMatch && valueSetHasNull: + writerBool.Set() + writerNulls.Set() + case state.NullBehavior == NullMatchingSkip || (!valueSetHasNull && state.NullBehavior == NullMatchingMatch): + writerBool.Clear() + writerNulls.Set() + default: + writerBool.Clear() + writerNulls.Clear() + } + + writerBool.Next() + writerNulls.Next() + return nil + }) +} diff --git a/arrow/compute/internal/kernels/vector_hash.go b/arrow/compute/internal/kernels/vector_hash.go index 51968f79..bb0c561a 100644 --- a/arrow/compute/internal/kernels/vector_hash.go +++ b/arrow/compute/internal/kernels/vector_hash.go @@ -345,10 +345,10 @@ func newMemoTable(mem memory.Allocator, dt arrow.Type) (hashing.MemoTable, error return hashing.NewUint8MemoTable(0), nil case arrow.INT16, arrow.UINT16: return hashing.NewUint16MemoTable(0), nil - case arrow.INT32, arrow.UINT32, arrow.FLOAT32, + case arrow.INT32, arrow.UINT32, arrow.FLOAT32, arrow.DECIMAL32, arrow.DATE32, arrow.TIME32, arrow.INTERVAL_MONTHS: return hashing.NewUint32MemoTable(0), nil - case arrow.INT64, arrow.UINT64, arrow.FLOAT64, + case arrow.INT64, arrow.UINT64, arrow.FLOAT64, arrow.DECIMAL64, arrow.DATE64, arrow.TIME64, arrow.TIMESTAMP, arrow.DURATION, arrow.INTERVAL_DAY_TIME: return hashing.NewUint64MemoTable(0), nil @@ -481,7 +481,7 @@ func uniqueFinalize(ctx *exec.KernelCtx, results []*exec.ArraySpan) ([]*exec.Arr return []*exec.ArraySpan{&out}, nil } -func ensureHashDictionary(ctx *exec.KernelCtx, hash *dictionaryHashState) (*exec.ArraySpan, error) { +func ensureHashDictionary(_ *exec.KernelCtx, hash *dictionaryHashState) (*exec.ArraySpan, error) { out := &exec.ArraySpan{} if hash.dictionary != nil { diff --git a/arrow/compute/internal/kernels/vector_selection.go b/arrow/compute/internal/kernels/vector_selection.go index 4a619406..9bbc8635 100644 --- a/arrow/compute/internal/kernels/vector_selection.go +++ b/arrow/compute/internal/kernels/vector_selection.go @@ -906,13 +906,13 @@ func takeIdxDispatch[ValT arrow.IntType](values, indices *exec.ArraySpan, out *e switch indices.Type.(arrow.FixedWidthDataType).Bytes() { case 1: - primitiveTakeImpl[uint8, ValT](getter, indices, out) + primitiveTakeImpl[uint8](getter, indices, out) case 2: - primitiveTakeImpl[uint16, ValT](getter, indices, out) + primitiveTakeImpl[uint16](getter, indices, out) case 4: - primitiveTakeImpl[uint32, ValT](getter, indices, out) + primitiveTakeImpl[uint32](getter, indices, out) case 8: - primitiveTakeImpl[uint64, ValT](getter, indices, out) + primitiveTakeImpl[uint64](getter, indices, out) default: return fmt.Errorf("%w: invalid indices byte width", arrow.ErrIndex) } @@ -1147,7 +1147,7 @@ func filterExec(ctx *exec.KernelCtx, outputLen int64, values, selection *exec.Ar return nil } -func binaryFilterNonNull[OffsetT int32 | int64](ctx *exec.KernelCtx, values, filter *exec.ArraySpan, outputLen int64, nullSelection NullSelectionBehavior, out *exec.ExecResult) error { +func binaryFilterNonNull[OffsetT int32 | int64](ctx *exec.KernelCtx, values, filter *exec.ArraySpan, outputLen int64, _ NullSelectionBehavior, out *exec.ExecResult) error { var ( offsetBuilder = newBufferBuilder[OffsetT](exec.GetAllocator(ctx.Ctx)) dataBuilder = newBufferBuilder[uint8](exec.GetAllocator(ctx.Ctx)) diff --git a/arrow/compute/registry.go b/arrow/compute/registry.go index 12bc0b85..6b9250c9 100644 --- a/arrow/compute/registry.go 
+++ b/arrow/compute/registry.go @@ -53,6 +53,7 @@ func GetFunctionRegistry() FunctionRegistry { RegisterScalarComparisons(registry) RegisterVectorHash(registry) RegisterVectorRunEndFuncs(registry) + RegisterScalarSetLookup(registry) }) return registry } diff --git a/arrow/compute/scalar_set_lookup.go b/arrow/compute/scalar_set_lookup.go new file mode 100644 index 00000000..81971cea --- /dev/null +++ b/arrow/compute/scalar_set_lookup.go @@ -0,0 +1,222 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package compute + +import ( + "context" + "errors" + "fmt" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/compute/exec" + "github.com/apache/arrow-go/v18/arrow/compute/internal/kernels" + "github.com/apache/arrow-go/v18/arrow/extensions" + "github.com/apache/arrow-go/v18/internal/hashing" +) + +var ( + isinDoc = FunctionDoc{ + Summary: "Find each element in a set of values", + Description: `For each element in "values", return true if it is found +in a given set, false otherwise`, + ArgNames: []string{"values"}, + OptionsType: "SetOptions", + OptionsRequired: true, + } +) + +type NullMatchingBehavior = kernels.NullMatchingBehavior + +const ( + NullMatchingMatch = kernels.NullMatchingMatch + NullMatchingSkip = kernels.NullMatchingSkip + NullMatchingEmitNull = kernels.NullMatchingEmitNull + NullMatchingInconclusive = kernels.NullMatchingInconclusive +) + +type setLookupFunc struct { + ScalarFunction +} + +func (fn *setLookupFunc) Execute(ctx context.Context, opts FunctionOptions, args ...Datum) (Datum, error) { + return execInternal(ctx, fn, opts, -1, args...) +} + +func (fn *setLookupFunc) DispatchBest(vals ...arrow.DataType) (exec.Kernel, error) { + ensureDictionaryDecoded(vals...) + return fn.DispatchExact(vals...) +} + +type SetOptions struct { + ValueSet Datum + NullBehavior NullMatchingBehavior +} + +func (*SetOptions) TypeName() string { return "SetOptions" } + +func initSetLookup(ctx *exec.KernelCtx, args exec.KernelInitArgs) (exec.KernelState, error) { + if args.Options == nil { + return nil, fmt.Errorf("%w: calling a set lookup function without SetOptions", ErrInvalid) + } + + opts, ok := args.Options.(*SetOptions) + if !ok { + return nil, fmt.Errorf("%w: expected SetOptions, got %T", ErrInvalid, args.Options) + } + + valueset, ok := opts.ValueSet.(ArrayLikeDatum) + if !ok { + return nil, fmt.Errorf("%w: expected array-like datum, got %T", ErrInvalid, opts.ValueSet) + } + + argType := args.Inputs[0] + if (argType.ID() == arrow.STRING || argType.ID() == arrow.LARGE_STRING) && !arrow.IsBaseBinary(valueset.Type().ID()) { + // don't implicitly cast from a non-binary type to string + // since most types support casting to string and that may lead to + // surprises. 
However we do want most other implicit casts + return nil, fmt.Errorf("%w: array type doesn't match type of values set: %s vs %s", + ErrInvalid, argType, valueset.Type()) + } + + if !arrow.TypeEqual(valueset.Type(), argType) { + result, err := CastDatum(ctx.Ctx, valueset, SafeCastOptions(argType)) + if err == nil { + defer result.Release() + valueset = result.(ArrayLikeDatum) + } else if CanCast(argType, valueset.Type()) { + // avoid casting from non-binary types to string like above + // otherwise will try to cast input array to valueset during + // execution + if (valueset.Type().ID() == arrow.STRING || valueset.Type().ID() == arrow.LARGE_STRING) && !arrow.IsBaseBinary(argType.ID()) { + return nil, fmt.Errorf("%w: array type doesn't match type of values set: %s vs %s", + ErrInvalid, argType, valueset.Type()) + } + } else { + return nil, fmt.Errorf("%w: array type doesn't match type of values set: %s vs %s", + ErrInvalid, argType, valueset.Type()) + } + + } + + internalOpts := kernels.SetLookupOptions{ + ValueSet: make([]exec.ArraySpan, 1), + TotalLen: opts.ValueSet.Len(), + NullBehavior: opts.NullBehavior, + } + + switch valueset.Kind() { + case KindArray: + internalOpts.ValueSet[0].SetMembers(valueset.(*ArrayDatum).Value) + internalOpts.ValueSetType = valueset.(*ArrayDatum).Type() + case KindChunked: + chnked := valueset.(*ChunkedDatum).Value + internalOpts.ValueSetType = chnked.DataType() + internalOpts.ValueSet = make([]exec.ArraySpan, len(chnked.Chunks())) + for i, c := range chnked.Chunks() { + internalOpts.ValueSet[i].SetMembers(c.Data()) + } + default: + return nil, fmt.Errorf("%w: expected array or chunked array, got %s", ErrInvalid, opts.ValueSet.Kind()) + } + + return kernels.CreateSetLookupState(internalOpts, exec.GetAllocator(ctx.Ctx)) +} + +type setLookupState interface { + Init(kernels.SetLookupOptions) error + ValueType() arrow.DataType +} + +func execIsIn(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult) error { + state := ctx.State.(setLookupState) + ctx.Kernel.(*exec.ScalarKernel).Data = state + in := batch.Values[0] + + if !arrow.TypeEqual(in.Type(), state.ValueType()) { + materialized := in.Array.MakeArray() + defer materialized.Release() + + castResult, err := CastArray(ctx.Ctx, materialized, SafeCastOptions(state.ValueType())) + if err != nil { + if errors.Is(err, arrow.ErrNotImplemented) { + return fmt.Errorf("%w: array type doesn't match type of values set: %s vs %s", + ErrInvalid, in.Type(), state.ValueType()) + } + return err + } + defer castResult.Release() + + var casted exec.ArraySpan + casted.SetMembers(castResult.Data()) + return kernels.DispatchIsIn(state, &casted, out) + } + + return kernels.DispatchIsIn(state, &in.Array, out) +} + +func IsIn(ctx context.Context, opts SetOptions, values Datum) (Datum, error) { + return CallFunction(ctx, "is_in", &opts, values) +} + +func IsInSet(ctx context.Context, valueSet, values Datum) (Datum, error) { + return IsIn(ctx, SetOptions{ValueSet: valueSet}, values) +} + +func RegisterScalarSetLookup(reg FunctionRegistry) { + inBase := NewScalarFunction("is_in", Unary(), isinDoc) + + types := []exec.InputType{ + exec.NewMatchedInput(exec.Primitive()), + exec.NewIDInput(arrow.DECIMAL32), + exec.NewIDInput(arrow.DECIMAL64), + } + + outType := exec.NewOutputType(arrow.FixedWidthTypes.Boolean) + for _, ty := range types { + kn := exec.NewScalarKernel([]exec.InputType{ty}, outType, execIsIn, initSetLookup) + kn.MemAlloc = exec.MemPrealloc + kn.NullHandling = exec.NullComputedPrealloc + if err := 
inBase.AddKernel(kn); err != nil { + panic(err) + } + } + + binaryTypes := []exec.InputType{ + exec.NewMatchedInput(exec.BinaryLike()), + exec.NewMatchedInput(exec.LargeBinaryLike()), + exec.NewExactInput(extensions.NewUUIDType()), + exec.NewIDInput(arrow.FIXED_SIZE_BINARY), + exec.NewIDInput(arrow.DECIMAL128), + exec.NewIDInput(arrow.DECIMAL256), + } + for _, ty := range binaryTypes { + kn := exec.NewScalarKernel([]exec.InputType{ty}, outType, execIsIn, initSetLookup) + kn.MemAlloc = exec.MemPrealloc + kn.NullHandling = exec.NullComputedPrealloc + kn.CleanupFn = func(state exec.KernelState) error { + s := state.(*kernels.SetLookupState[[]byte]) + s.Lookup.(*hashing.BinaryMemoTable).Release() + return nil + } + + if err := inBase.AddKernel(kn); err != nil { + panic(err) + } + } + + reg.AddFunction(&setLookupFunc{*inBase}, false) +} diff --git a/arrow/compute/scalar_set_lookup_test.go b/arrow/compute/scalar_set_lookup_test.go new file mode 100644 index 00000000..770b984f --- /dev/null +++ b/arrow/compute/scalar_set_lookup_test.go @@ -0,0 +1,606 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
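The new is_in entry points are exercised at length by the test suite below; note also that RegisterScalarSetLookup above relies on the new CleanupFn hook so the binary kernels can release their memo tables after execution. As a quick orientation, a minimal hedged sketch of calling IsIn directly (the values, context setup, and imports are illustrative, not taken from this change):

// Hedged sketch: ask which input values appear in a value set using the new
// IsIn compute function; NullMatchingMatch treats nulls as matching nulls.
mem := memory.NewGoAllocator()
ctx := compute.WithAllocator(context.Background(), mem)

input, _, err := array.FromJSON(mem, arrow.PrimitiveTypes.Int32, strings.NewReader(`[0, 1, 2, null]`))
if err != nil {
	// handle error
}
defer input.Release()

valueSet, _, err := array.FromJSON(mem, arrow.PrimitiveTypes.Int32, strings.NewReader(`[2, 1]`))
if err != nil {
	// handle error
}
defer valueSet.Release()

out, err := compute.IsIn(ctx, compute.SetOptions{
	ValueSet:     compute.NewDatumWithoutOwning(valueSet),
	NullBehavior: compute.NullMatchingMatch,
}, compute.NewDatumWithoutOwning(input))
if err != nil {
	// handle error
}
defer out.Release()
// out should be a boolean datum equal to [false, true, true, false]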
+ +package compute_test + +import ( + "context" + "strings" + "testing" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/array" + "github.com/apache/arrow-go/v18/arrow/compute" + "github.com/apache/arrow-go/v18/arrow/memory" + "github.com/stretchr/testify/suite" +) + +type ScalarSetLookupSuite struct { + suite.Suite + + mem *memory.CheckedAllocator + ctx context.Context +} + +func (ss *ScalarSetLookupSuite) SetupTest() { + ss.mem = memory.NewCheckedAllocator(memory.DefaultAllocator) + ss.ctx = compute.WithAllocator(context.TODO(), ss.mem) +} + +func (ss *ScalarSetLookupSuite) getArr(dt arrow.DataType, str string) arrow.Array { + arr, _, err := array.FromJSON(ss.mem, dt, strings.NewReader(str), array.WithUseNumber()) + ss.Require().NoError(err) + return arr +} + +func (ss *ScalarSetLookupSuite) checkIsIn(input, valueSet arrow.Array, expectedJSON string, matching compute.NullMatchingBehavior) { + expected := ss.getArr(arrow.FixedWidthTypes.Boolean, expectedJSON) + defer expected.Release() + + result, err := compute.IsIn(ss.ctx, compute.SetOptions{ + ValueSet: compute.NewDatumWithoutOwning(valueSet), + NullBehavior: matching, + }, compute.NewDatumWithoutOwning(input)) + ss.Require().NoError(err) + defer result.Release() + + assertDatumsEqual(ss.T(), compute.NewDatumWithoutOwning(expected), result, nil, nil) +} + +func (ss *ScalarSetLookupSuite) checkIsInFromJSON(typ arrow.DataType, input, valueSet, expected string, matching compute.NullMatchingBehavior) { + inputArr := ss.getArr(typ, input) + defer inputArr.Release() + + valueSetArr := ss.getArr(typ, valueSet) + defer valueSetArr.Release() + + ss.checkIsIn(inputArr, valueSetArr, expected, matching) +} + +func (ss *ScalarSetLookupSuite) checkIsInDictionary(typ, idxType arrow.DataType, inputDict, inputIndex, valueSet, expected string, matching compute.NullMatchingBehavior) { + dictType := &arrow.DictionaryType{IndexType: idxType, ValueType: typ} + indices := ss.getArr(idxType, inputIndex) + defer indices.Release() + dict := ss.getArr(typ, inputDict) + defer dict.Release() + + input := array.NewDictionaryArray(dictType, indices, dict) + defer input.Release() + + valueSetArr := ss.getArr(typ, valueSet) + defer valueSetArr.Release() + + ss.checkIsIn(input, valueSetArr, expected, matching) +} + +func (ss *ScalarSetLookupSuite) checkIsInChunked(input, value, expected *arrow.Chunked, matching compute.NullMatchingBehavior) { + result, err := compute.IsIn(ss.ctx, compute.SetOptions{ + ValueSet: compute.NewDatumWithoutOwning(value), + NullBehavior: matching, + }, compute.NewDatumWithoutOwning(input)) + ss.Require().NoError(err) + defer result.Release() + + ss.Len(result.(*compute.ChunkedDatum).Chunks(), 1) + assertDatumsEqual(ss.T(), compute.NewDatumWithoutOwning(expected), result, nil, nil) +} + +func (ss *ScalarSetLookupSuite) TestIsInPrimitive() { + type testCase struct { + expected string + matching compute.NullMatchingBehavior + } + + tests := []struct { + name string + input string + valueset string + cases []testCase + }{ + {"no nulls", `[0, 1, 2, 3, 2]`, `[2, 1]`, []testCase{ + {`[false, true, true, false, true]`, compute.NullMatchingMatch}, + }}, + {"nulls in left", `[null, 1, 2, 3, 2]`, `[2, 1]`, []testCase{ + {`[false, true, true, false, true]`, compute.NullMatchingMatch}, + {`[false, true, true, false, true]`, compute.NullMatchingSkip}, + {`[null, true, true, false, true]`, compute.NullMatchingEmitNull}, + {`[null, true, true, false, true]`, compute.NullMatchingInconclusive}, + }}, + {"nulls in right", `[0, 
1, 2, 3, 2]`, `[2, null, 1]`, []testCase{ + {`[false, true, true, false, true]`, compute.NullMatchingMatch}, + {`[false, true, true, false, true]`, compute.NullMatchingSkip}, + {`[false, true, true, false, true]`, compute.NullMatchingEmitNull}, + {`[null, true, true, null, true]`, compute.NullMatchingInconclusive}, + }}, + {"nulls in both", `[null, 1, 2, 3, 2]`, `[2, null, 1]`, []testCase{ + {`[true, true, true, false, true]`, compute.NullMatchingMatch}, + {`[false, true, true, false, true]`, compute.NullMatchingSkip}, + {`[null, true, true, false, true]`, compute.NullMatchingEmitNull}, + {`[null, true, true, null, true]`, compute.NullMatchingInconclusive}, + }}, + {"duplicates in right", `[null, 1, 2, 3, 2]`, `[null, 2, 2, null, 1, 1]`, []testCase{ + {`[true, true, true, false, true]`, compute.NullMatchingMatch}, + {`[false, true, true, false, true]`, compute.NullMatchingSkip}, + {`[null, true, true, false, true]`, compute.NullMatchingEmitNull}, + {`[null, true, true, null, true]`, compute.NullMatchingInconclusive}, + }}, + {"empty arrays", `[]`, `[]`, []testCase{ + {`[]`, compute.NullMatchingMatch}, + }}, + } + + typList := append([]arrow.DataType{}, numericTypes...) + typList = append(typList, arrow.FixedWidthTypes.Time32s, + arrow.FixedWidthTypes.Time32ms, arrow.FixedWidthTypes.Time64us, + arrow.FixedWidthTypes.Time64ns, arrow.FixedWidthTypes.Timestamp_us, + arrow.FixedWidthTypes.Timestamp_ns, arrow.FixedWidthTypes.Duration_s, + arrow.FixedWidthTypes.Duration_ms, arrow.FixedWidthTypes.Duration_us, + arrow.FixedWidthTypes.Duration_ns) + + for _, typ := range typList { + ss.Run(typ.String(), func() { + for _, tt := range tests { + ss.Run(tt.name, func() { + for _, tc := range tt.cases { + ss.checkIsInFromJSON(typ, + tt.input, tt.valueset, tc.expected, tc.matching) + } + }) + } + }) + } +} + +func (ss *ScalarSetLookupSuite) TestDurationCasts() { + vals := ss.getArr(arrow.FixedWidthTypes.Duration_s, `[0, 1, 2]`) + defer vals.Release() + + valueset := ss.getArr(arrow.FixedWidthTypes.Duration_ms, `[1, 2, 2000]`) + defer valueset.Release() + + ss.checkIsIn(vals, valueset, `[false, false, true]`, compute.NullMatchingMatch) +} + +func (ss *ScalarSetLookupSuite) TestIsInBinary() { + type testCase struct { + expected string + matching compute.NullMatchingBehavior + } + + tests := []struct { + name string + input string + valueset string + cases []testCase + }{ + {"nulls on left", `["YWFh", "", "Y2M=", null, ""]`, `["YWFh", ""]`, []testCase{ + {`[true, true, false, false, true]`, compute.NullMatchingMatch}, + {`[true, true, false, false, true]`, compute.NullMatchingSkip}, + {`[true, true, false, null, true]`, compute.NullMatchingEmitNull}, + {`[true, true, false, null, true]`, compute.NullMatchingInconclusive}, + }}, + {"nulls on right", `["YWFh", "", "Y2M=", null, ""]`, `["YWFh", "", null]`, []testCase{ + {`[true, true, false, true, true]`, compute.NullMatchingMatch}, + {`[true, true, false, false, true]`, compute.NullMatchingSkip}, + {`[true, true, false, null, true]`, compute.NullMatchingEmitNull}, + {`[true, true, null, null, true]`, compute.NullMatchingInconclusive}, + }}, + {"duplicates in right array", `["YWFh", "", "Y2M=", null, ""]`, `[null, "YWFh", "YWFh", "", "", null]`, []testCase{ + {`[true, true, false, true, true]`, compute.NullMatchingMatch}, + {`[true, true, false, false, true]`, compute.NullMatchingSkip}, + {`[true, true, false, null, true]`, compute.NullMatchingEmitNull}, + {`[true, true, null, null, true]`, compute.NullMatchingInconclusive}, + }}, + } + + for _, typ := range 
baseBinaryTypes { + ss.Run(typ.String(), func() { + for _, tt := range tests { + ss.Run(tt.name, func() { + for _, tc := range tt.cases { + ss.checkIsInFromJSON(typ, + tt.input, tt.valueset, tc.expected, tc.matching) + } + }) + } + }) + } +} + +func (ss *ScalarSetLookupSuite) TestIsInFixedSizeBinary() { + type testCase struct { + expected string + matching compute.NullMatchingBehavior + } + + tests := []struct { + name string + input string + valueset string + cases []testCase + }{ + {"nulls on left", `["YWFh", "YmJi", "Y2Nj", null, "YmJi"]`, `["YWFh", "YmJi"]`, []testCase{ + {`[true, true, false, false, true]`, compute.NullMatchingMatch}, + {`[true, true, false, false, true]`, compute.NullMatchingSkip}, + {`[true, true, false, null, true]`, compute.NullMatchingEmitNull}, + {`[true, true, false, null, true]`, compute.NullMatchingInconclusive}, + }}, + {"nulls on right", `["YWFh", "YmJi", "Y2Nj", null, "YmJi"]`, `["YWFh", "YmJi", null]`, []testCase{ + {`[true, true, false, true, true]`, compute.NullMatchingMatch}, + {`[true, true, false, false, true]`, compute.NullMatchingSkip}, + {`[true, true, false, null, true]`, compute.NullMatchingEmitNull}, + {`[true, true, null, null, true]`, compute.NullMatchingInconclusive}, + }}, + {"duplicates in right array", `["YWFh", "YmJi", "Y2Nj", null, "YmJi"]`, `["YWFh", null, "YWFh", "YmJi", "YmJi", null]`, []testCase{ + {`[true, true, false, true, true]`, compute.NullMatchingMatch}, + {`[true, true, false, false, true]`, compute.NullMatchingSkip}, + {`[true, true, false, null, true]`, compute.NullMatchingEmitNull}, + {`[true, true, null, null, true]`, compute.NullMatchingInconclusive}, + }}, + } + + typ := &arrow.FixedSizeBinaryType{ByteWidth: 3} + for _, tt := range tests { + ss.Run(tt.name, func() { + for _, tc := range tt.cases { + ss.checkIsInFromJSON(typ, + tt.input, tt.valueset, tc.expected, tc.matching) + } + }) + } +} + +func (ss *ScalarSetLookupSuite) TestIsInDecimal() { + type testCase struct { + expected string + matching compute.NullMatchingBehavior + } + + tests := []struct { + name string + input string + valueset string + cases []testCase + }{ + {"nulls on left", `["12.3", "45.6", "78.9", null, "12.3"]`, `["12.3", "78.9"]`, []testCase{ + {`[true, false, true, false, true]`, compute.NullMatchingMatch}, + {`[true, false, true, false, true]`, compute.NullMatchingSkip}, + {`[true, false, true, null, true]`, compute.NullMatchingEmitNull}, + {`[true, false, true, null, true]`, compute.NullMatchingInconclusive}, + }}, + {"nulls on right", `["12.3", "45.6", "78.9", null, "12.3"]`, `["12.3", "78.9", null]`, []testCase{ + {`[true, false, true, true, true]`, compute.NullMatchingMatch}, + {`[true, false, true, false, true]`, compute.NullMatchingSkip}, + {`[true, false, true, null, true]`, compute.NullMatchingEmitNull}, + {`[true, null, true, null, true]`, compute.NullMatchingInconclusive}, + }}, + {"duplicates in right array", `["12.3", "45.6", "78.9", null, "12.3"]`, `[null, "12.3", "12.3", "78.9", "78.9", null]`, []testCase{ + {`[true, false, true, true, true]`, compute.NullMatchingMatch}, + {`[true, false, true, false, true]`, compute.NullMatchingSkip}, + {`[true, false, true, null, true]`, compute.NullMatchingEmitNull}, + {`[true, null, true, null, true]`, compute.NullMatchingInconclusive}, + }}, + } + + decTypes := []arrow.DataType{ + &arrow.Decimal32Type{Precision: 3, Scale: 1}, + &arrow.Decimal64Type{Precision: 3, Scale: 1}, + &arrow.Decimal128Type{Precision: 3, Scale: 1}, + &arrow.Decimal256Type{Precision: 3, Scale: 1}, + } + + for _, typ := 
range decTypes { + ss.Run(typ.String(), func() { + for _, tt := range tests { + ss.Run(tt.name, func() { + for _, tc := range tt.cases { + ss.checkIsInFromJSON(typ, + tt.input, tt.valueset, tc.expected, tc.matching) + } + }) + } + + // don't yet have Decimal32 or Decimal64 implemented for casting + if typ.ID() == arrow.DECIMAL128 || typ.ID() == arrow.DECIMAL256 { + // test cast + in := ss.getArr(&arrow.Decimal128Type{Precision: 4, Scale: 2}, `["12.30", "45.60", "78.90"]`) + defer in.Release() + values := ss.getArr(typ, `["12.3", "78.9"]`) + defer values.Release() + + ss.checkIsIn(in, values, `[true, false, true]`, compute.NullMatchingMatch) + } + }) + } +} + +func (ss *ScalarSetLookupSuite) TestIsInDictionary() { + tests := []struct { + typ arrow.DataType + inputDict string + inputIdx string + valueSet string + expected string + matching compute.NullMatchingBehavior + }{ + { + typ: arrow.BinaryTypes.String, + inputDict: `["A", "B", "C", "D"]`, + inputIdx: `[1, 2, null, 0]`, + valueSet: `["A", "B", "C"]`, + expected: `[true, true, false, true]`, + matching: compute.NullMatchingMatch, + }, + { + typ: arrow.PrimitiveTypes.Float32, + inputDict: `[4.1, -1.0, 42, 9.8]`, + inputIdx: `[1, 2, null, 0]`, + valueSet: `[4.1, 42, -1.0]`, + expected: `[true, true, false, true]`, + matching: compute.NullMatchingMatch, + }, + { + typ: arrow.BinaryTypes.String, + inputDict: `["A", "B", "C", "D"]`, + inputIdx: `[1, 3, null, 0, 1]`, + valueSet: `["C", "B", "A", null]`, + expected: `[true, false, true, true, true]`, + matching: compute.NullMatchingMatch, + }, + { + typ: arrow.BinaryTypes.String, + inputDict: `["A", null, "C", "D"]`, + inputIdx: `[1, 3, null, 0, 1]`, + valueSet: `["C", "B", "A", null]`, + expected: `[true, false, true, true, true]`, + matching: compute.NullMatchingMatch, + }, + { + typ: arrow.BinaryTypes.String, + inputDict: `["A", null, "C", "D"]`, + inputIdx: `[1, 3, null, 0, 1]`, + valueSet: `["C", "B", "A"]`, + expected: `[false, false, false, true, false]`, + matching: compute.NullMatchingMatch, + }, + { + typ: arrow.BinaryTypes.String, + inputDict: `["A", "B", "C", "D"]`, + inputIdx: `[1, 3, null, 0, 1]`, + valueSet: `["C", "B", "A", null]`, + expected: `[true, false, false, true, true]`, + matching: compute.NullMatchingSkip, + }, + { + typ: arrow.BinaryTypes.String, + inputDict: `["A", null, "C", "D"]`, + inputIdx: `[1, 3, null, 0, 1]`, + valueSet: `["C", "B", "A", null]`, + expected: `[false, false, false, true, false]`, + matching: compute.NullMatchingSkip, + }, + { + typ: arrow.BinaryTypes.String, + inputDict: `["A", null, "C", "D"]`, + inputIdx: `[1, 3, null, 0, 1]`, + valueSet: `["C", "B", "A"]`, + expected: `[false, false, false, true, false]`, + matching: compute.NullMatchingSkip, + }, + { + typ: arrow.BinaryTypes.String, + inputDict: `["A", "B", "C", "D"]`, + inputIdx: `[1, 3, null, 0, 1]`, + valueSet: `["C", "B", "A", null]`, + expected: `[true, false, null, true, true]`, + matching: compute.NullMatchingEmitNull, + }, + { + typ: arrow.BinaryTypes.String, + inputDict: `["A", null, "C", "D"]`, + inputIdx: `[1, 3, null, 0, 1]`, + valueSet: `["C", "B", "A", null]`, + expected: `[null, false, null, true, null]`, + matching: compute.NullMatchingEmitNull, + }, + { + typ: arrow.BinaryTypes.String, + inputDict: `["A", null, "C", "D"]`, + inputIdx: `[1, 3, null, 0, 1]`, + valueSet: `["C", "B", "A"]`, + expected: `[null, false, null, true, null]`, + matching: compute.NullMatchingEmitNull, + }, + { + typ: arrow.BinaryTypes.String, + inputDict: `["A", "B", "C", "D"]`, + inputIdx: `[1, 3,
null, 0, 1]`, + valueSet: `["C", "B", "A", null]`, + expected: `[true, null, null, true, true]`, + matching: compute.NullMatchingInconclusive, + }, + { + typ: arrow.BinaryTypes.String, + inputDict: `["A", null, "C", "D"]`, + inputIdx: `[1, 3, null, 0, 1]`, + valueSet: `["C", "B", "A", null]`, + expected: `[null, null, null, true, null]`, + matching: compute.NullMatchingInconclusive, + }, + { + typ: arrow.BinaryTypes.String, + inputDict: `["A", null, "C", "D"]`, + inputIdx: `[1, 3, null, 0, 1]`, + valueSet: `["C", "B", "A"]`, + expected: `[null, false, null, true, null]`, + matching: compute.NullMatchingInconclusive, + }, + { + typ: arrow.BinaryTypes.String, + inputDict: `["A", "B", "C", "D"]`, + inputIdx: `[1, 2, null, 0]`, + valueSet: `["A", "A", "B", "A", "B", "C"]`, + expected: `[true, true, false, true]`, + matching: compute.NullMatchingMatch, + }, + { + typ: arrow.BinaryTypes.String, + inputDict: `["A", "B", "C", "D"]`, + inputIdx: `[1, 3, null, 0, 1]`, + valueSet: `["C", "C", "B", "A", null, null, "B"]`, + expected: `[true, false, true, true, true]`, + matching: compute.NullMatchingMatch, + }, + { + typ: arrow.BinaryTypes.String, + inputDict: `["A", "B", "C", "D"]`, + inputIdx: `[1, 3, null, 0, 1]`, + valueSet: `["C", "C", "B", "A", null, null, "B"]`, + expected: `[true, false, false, true, true]`, + matching: compute.NullMatchingSkip, + }, + { + typ: arrow.BinaryTypes.String, + inputDict: `["A", "B", "C", "D"]`, + inputIdx: `[1, 3, null, 0, 1]`, + valueSet: `["C", "C", "B", "A", null, null, "B"]`, + expected: `[true, false, null, true, true]`, + matching: compute.NullMatchingEmitNull, + }, + { + typ: arrow.BinaryTypes.String, + inputDict: `["A", "B", "C", "D"]`, + inputIdx: `[1, 3, null, 0, 1]`, + valueSet: `["C", "C", "B", "A", null, null, "B"]`, + expected: `[true, null, null, true, true]`, + matching: compute.NullMatchingInconclusive, + }, + } + + for _, ty := range dictIndexTypes { + ss.Run("idx="+ty.String(), func() { + for _, test := range tests { + ss.Run(test.typ.String(), func() { + ss.checkIsInDictionary(test.typ, ty, + test.inputDict, test.inputIdx, test.valueSet, + test.expected, test.matching) + }) + } + }) + } +} + +func (ss *ScalarSetLookupSuite) TestIsInChunked() { + input, err := array.ChunkedFromJSON(ss.mem, arrow.BinaryTypes.String, + []string{`["abc", "def", "", "abc", "jkl"]`, `["def", null, "abc", "zzz"]`}) + ss.Require().NoError(err) + defer input.Release() + + valueSet, err := array.ChunkedFromJSON(ss.mem, arrow.BinaryTypes.String, + []string{`["", "def"]`, `["abc"]`}) + ss.Require().NoError(err) + defer valueSet.Release() + + expected, err := array.ChunkedFromJSON(ss.mem, arrow.FixedWidthTypes.Boolean, + []string{`[true, true, true, true, false]`, `[true, false, true, false]`}) + ss.Require().NoError(err) + defer expected.Release() + + ss.checkIsInChunked(input, valueSet, expected, compute.NullMatchingMatch) + ss.checkIsInChunked(input, valueSet, expected, compute.NullMatchingSkip) + + expected, err = array.ChunkedFromJSON(ss.mem, arrow.FixedWidthTypes.Boolean, + []string{`[true, true, true, true, false]`, `[true, null, true, false]`}) + ss.Require().NoError(err) + defer expected.Release() + + ss.checkIsInChunked(input, valueSet, expected, compute.NullMatchingEmitNull) + ss.checkIsInChunked(input, valueSet, expected, compute.NullMatchingInconclusive) + + valueSet, err = array.ChunkedFromJSON(ss.mem, arrow.BinaryTypes.String, + []string{`["", "def"]`, `[null]`}) + ss.Require().NoError(err) + defer valueSet.Release() + + expected, err = 
array.ChunkedFromJSON(ss.mem, arrow.FixedWidthTypes.Boolean, + []string{`[false, true, true, false, false]`, `[true, true, false, false]`}) + ss.Require().NoError(err) + defer expected.Release() + + ss.checkIsInChunked(input, valueSet, expected, compute.NullMatchingMatch) + + expected, err = array.ChunkedFromJSON(ss.mem, arrow.FixedWidthTypes.Boolean, + []string{`[false, true, true, false, false]`, `[true, false, false, false]`}) + ss.Require().NoError(err) + defer expected.Release() + + ss.checkIsInChunked(input, valueSet, expected, compute.NullMatchingSkip) + + expected, err = array.ChunkedFromJSON(ss.mem, arrow.FixedWidthTypes.Boolean, + []string{`[false, true, true, false, false]`, `[true, null, false, false]`}) + ss.Require().NoError(err) + defer expected.Release() + + ss.checkIsInChunked(input, valueSet, expected, compute.NullMatchingEmitNull) + + expected, err = array.ChunkedFromJSON(ss.mem, arrow.FixedWidthTypes.Boolean, + []string{`[null, true, true, null, null]`, `[true, null, null, null]`}) + ss.Require().NoError(err) + defer expected.Release() + + ss.checkIsInChunked(input, valueSet, expected, compute.NullMatchingInconclusive) + + valueSet, err = array.ChunkedFromJSON(ss.mem, arrow.BinaryTypes.String, + []string{`["", null, "", "def"]`, `["def", null]`}) + ss.Require().NoError(err) + defer valueSet.Release() + + expected, err = array.ChunkedFromJSON(ss.mem, arrow.FixedWidthTypes.Boolean, + []string{`[false, true, true, false, false]`, `[true, true, false, false]`}) + ss.Require().NoError(err) + defer expected.Release() + + ss.checkIsInChunked(input, valueSet, expected, compute.NullMatchingMatch) + + expected, err = array.ChunkedFromJSON(ss.mem, arrow.FixedWidthTypes.Boolean, + []string{`[false, true, true, false, false]`, `[true, false, false, false]`}) + ss.Require().NoError(err) + defer expected.Release() + + ss.checkIsInChunked(input, valueSet, expected, compute.NullMatchingSkip) + + expected, err = array.ChunkedFromJSON(ss.mem, arrow.FixedWidthTypes.Boolean, + []string{`[false, true, true, false, false]`, `[true, null, false, false]`}) + ss.Require().NoError(err) + defer expected.Release() + + ss.checkIsInChunked(input, valueSet, expected, compute.NullMatchingEmitNull) + + expected, err = array.ChunkedFromJSON(ss.mem, arrow.FixedWidthTypes.Boolean, + []string{`[null, true, true, null, null]`, `[true, null, null, null]`}) + ss.Require().NoError(err) + defer expected.Release() + + ss.checkIsInChunked(input, valueSet, expected, compute.NullMatchingInconclusive) +} + +func (ss *ScalarSetLookupSuite) TearDownTest() { + ss.mem.AssertSize(ss.T(), 0) +} + +func TestScalarSetLookup(t *testing.T) { + suite.Run(t, new(ScalarSetLookupSuite)) +} diff --git a/arrow/csv/reader.go b/arrow/csv/reader.go index dd0c0f18..db0f836d 100644 --- a/arrow/csv/reader.go +++ b/arrow/csv/reader.go @@ -43,7 +43,7 @@ type Reader struct { r *csv.Reader schema *arrow.Schema - refs int64 + refs atomic.Int64 bld *array.RecordBuilder cur arrow.Record err error @@ -75,10 +75,10 @@ type Reader struct { func NewInferringReader(r io.Reader, opts ...Option) *Reader { rr := &Reader{ r: csv.NewReader(r), - refs: 1, chunk: 1, stringsCanBeNull: false, } + rr.refs.Add(1) rr.r.ReuseRecord = true for _, opt := range opts { opt(rr) @@ -111,10 +111,10 @@ func NewReader(r io.Reader, schema *arrow.Schema, opts ...Option) *Reader { rr := &Reader{ r: csv.NewReader(r), schema: schema, - refs: 1, chunk: 1, stringsCanBeNull: false, } + rr.refs.Add(1) rr.r.ReuseRecord = true for _, opt := range opts { opt(rr) @@ -288,9 +288,7 @@ 
func (r *Reader) nextall() bool { r.done = true }() - var ( - recs [][]string - ) + var recs [][]string recs, r.err = r.r.ReadAll() if r.err != nil { @@ -926,16 +924,16 @@ func (r *Reader) parseExtension(field array.Builder, str string) { // Retain increases the reference count by 1. // Retain may be called simultaneously from multiple goroutines. func (r *Reader) Retain() { - atomic.AddInt64(&r.refs, 1) + r.refs.Add(1) } // Release decreases the reference count by 1. // When the reference count goes to zero, the memory is freed. // Release may be called simultaneously from multiple goroutines. func (r *Reader) Release() { - debug.Assert(atomic.LoadInt64(&r.refs) > 0, "too many releases") + debug.Assert(r.refs.Load() > 0, "too many releases") - if atomic.AddInt64(&r.refs, -1) == 0 { + if r.refs.Add(-1) == 0 { if r.cur != nil { r.cur.Release() } @@ -1025,6 +1023,4 @@ func tryParse(val string, dt arrow.DataType) error { panic("shouldn't end up here") } -var ( - _ array.RecordReader = (*Reader)(nil) -) +var _ array.RecordReader = (*Reader)(nil) diff --git a/arrow/doc.go b/arrow/doc.go index 690a4f53..7bc175c0 100644 --- a/arrow/doc.go +++ b/arrow/doc.go @@ -34,7 +34,7 @@ To build with tinygo include the noasm build tag. */ package arrow -const PkgVersion = "18.2.0" +const PkgVersion = "18.3.0" //go:generate go run _tools/tmpl/main.go -i -data=numeric.tmpldata type_traits_numeric.gen.go.tmpl type_traits_numeric.gen_test.go.tmpl array/numeric.gen.go.tmpl array/numericbuilder.gen.go.tmpl array/bufferbuilder_numeric.gen.go.tmpl //go:generate go run _tools/tmpl/main.go -i -data=datatype_numeric.gen.go.tmpldata datatype_numeric.gen.go.tmpl tensor/numeric.gen.go.tmpl tensor/numeric.gen_test.go.tmpl diff --git a/arrow/flight/flightsql/column_metadata.go b/arrow/flight/flightsql/column_metadata.go index d46fab30..10163aa5 100644 --- a/arrow/flight/flightsql/column_metadata.go +++ b/arrow/flight/flightsql/column_metadata.go @@ -50,6 +50,7 @@ const ( IsCaseSensitiveKey = "ARROW:FLIGHT:SQL:IS_CASE_SENSITIVE" IsReadOnlyKey = "ARROW:FLIGHT:SQL:IS_READ_ONLY" IsSearchableKey = "ARROW:FLIGHT:SQL:IS_SEARCHABLE" + RemarksKey = "ARROW:FLIGHT:SQL:REMARKS" ) // ColumnMetadata is a helper object for managing and querying the @@ -130,6 +131,10 @@ func (c *ColumnMetadata) IsSearchable() (bool, bool) { return c.findBoolVal(IsSearchableKey) } +func (c *ColumnMetadata) Remarks() (string, bool) { + return c.findStrVal(RemarksKey) +} + // ColumnMetadataBuilder is a convenience builder for constructing // sql column metadata using the expected standard metadata keys. 
// All methods return the builder itself so it can be chained @@ -215,3 +220,9 @@ func (c *ColumnMetadataBuilder) IsSearchable(v bool) *ColumnMetadataBuilder { c.vals = append(c.vals, boolToStr(v)) return c } + +func (c *ColumnMetadataBuilder) Remarks(remarks string) *ColumnMetadataBuilder { + c.keys = append(c.keys, RemarksKey) + c.vals = append(c.vals, remarks) + return c +} diff --git a/arrow/flight/flightsql/example/sql_batch_reader.go b/arrow/flight/flightsql/example/sql_batch_reader.go index 5a6cafe6..e0892e0a 100644 --- a/arrow/flight/flightsql/example/sql_batch_reader.go +++ b/arrow/flight/flightsql/example/sql_batch_reader.go @@ -99,7 +99,7 @@ func getArrowType(c *sql.ColumnType) arrow.DataType { const maxBatchSize = 1024 type SqlBatchReader struct { - refCount int64 + refCount atomic.Int64 schema *arrow.Schema rows *sql.Rows @@ -152,12 +152,15 @@ func NewSqlBatchReaderWithSchema(mem memory.Allocator, schema *arrow.Schema, row } } - return &SqlBatchReader{ - refCount: 1, - bldr: array.NewRecordBuilder(mem, schema), - schema: schema, - rowdest: rowdest, - rows: rows}, nil + sqb := &SqlBatchReader{ + bldr: array.NewRecordBuilder(mem, schema), + schema: schema, + rowdest: rowdest, + rows: rows, + } + + sqb.refCount.Add(1) + return sqb, nil } func NewSqlBatchReader(mem memory.Allocator, rows *sql.Rows) (*SqlBatchReader, error) { @@ -219,22 +222,24 @@ func NewSqlBatchReader(mem memory.Allocator, rows *sql.Rows) (*SqlBatchReader, e } schema := arrow.NewSchema(fields, nil) - return &SqlBatchReader{ - refCount: 1, - bldr: array.NewRecordBuilder(mem, schema), - schema: schema, - rowdest: rowdest, - rows: rows}, nil + sbr := &SqlBatchReader{ + bldr: array.NewRecordBuilder(mem, schema), + schema: schema, + rowdest: rowdest, + rows: rows, + } + sbr.refCount.Add(1) + return sbr, nil } func (r *SqlBatchReader) Retain() { - atomic.AddInt64(&r.refCount, 1) + r.refCount.Add(1) } func (r *SqlBatchReader) Release() { - debug.Assert(atomic.LoadInt64(&r.refCount) > 0, "too many releases") + debug.Assert(r.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&r.refCount, -1) == 0 { + if r.refCount.Add(-1) == 0 { r.rows.Close() r.rows, r.schema, r.rowdest = nil, nil, nil r.bldr.Release() diff --git a/arrow/flight/flightsql/example/sqlite_tables_schema_batch_reader.go b/arrow/flight/flightsql/example/sqlite_tables_schema_batch_reader.go index 3009635f..55b23903 100644 --- a/arrow/flight/flightsql/example/sqlite_tables_schema_batch_reader.go +++ b/arrow/flight/flightsql/example/sqlite_tables_schema_batch_reader.go @@ -35,7 +35,7 @@ import ( ) type SqliteTablesSchemaBatchReader struct { - refCount int64 + refCount atomic.Int64 mem memory.Allocator ctx context.Context @@ -57,24 +57,25 @@ func NewSqliteTablesSchemaBatchReader(ctx context.Context, mem memory.Allocator, return nil, err } - return &SqliteTablesSchemaBatchReader{ - refCount: 1, + stsbr := &SqliteTablesSchemaBatchReader{ ctx: ctx, rdr: rdr, stmt: stmt, mem: mem, schemaBldr: array.NewBinaryBuilder(mem, arrow.BinaryTypes.Binary), - }, nil + } + stsbr.refCount.Add(1) + return stsbr, nil } func (s *SqliteTablesSchemaBatchReader) Err() error { return s.err } -func (s *SqliteTablesSchemaBatchReader) Retain() { atomic.AddInt64(&s.refCount, 1) } +func (s *SqliteTablesSchemaBatchReader) Retain() { s.refCount.Add(1) } func (s *SqliteTablesSchemaBatchReader) Release() { - debug.Assert(atomic.LoadInt64(&s.refCount) > 0, "too many releases") + debug.Assert(s.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&s.refCount, -1) == 0 { + if 
s.refCount.Add(-1) == 0 { s.rdr.Release() s.stmt.Close() s.schemaBldr.Release() diff --git a/arrow/flight/record_batch_reader.go b/arrow/flight/record_batch_reader.go index c6596e82..c65d89f3 100644 --- a/arrow/flight/record_batch_reader.go +++ b/arrow/flight/record_batch_reader.go @@ -40,7 +40,7 @@ type dataMessageReader struct { rdr DataStreamReader peeked *FlightData - refCount int64 + refCount atomic.Int64 msg *ipc.Message lastAppMetadata []byte @@ -78,13 +78,13 @@ func (d *dataMessageReader) Message() (*ipc.Message, error) { } func (d *dataMessageReader) Retain() { - atomic.AddInt64(&d.refCount, 1) + d.refCount.Add(1) } func (d *dataMessageReader) Release() { - debug.Assert(atomic.LoadInt64(&d.refCount) > 0, "too many releases") + debug.Assert(d.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&d.refCount, -1) == 0 { + if d.refCount.Add(-1) == 0 { if d.msg != nil { d.msg.Release() d.msg = nil @@ -154,7 +154,8 @@ func NewRecordReader(r DataStreamReader, opts ...ipc.Option) (*Reader, error) { return nil, err } - rdr := &Reader{dmr: &dataMessageReader{rdr: r, refCount: 1}} + rdr := &Reader{dmr: &dataMessageReader{rdr: r}} + rdr.dmr.refCount.Add(1) rdr.dmr.descr = data.FlightDescriptor if len(data.DataHeader) > 0 { rdr.dmr.peeked = data diff --git a/arrow/internal/arrjson/reader.go b/arrow/internal/arrjson/reader.go index 0db94740..ec021fc1 100644 --- a/arrow/internal/arrjson/reader.go +++ b/arrow/internal/arrjson/reader.go @@ -28,7 +28,7 @@ import ( ) type Reader struct { - refs int64 + refs atomic.Int64 schema *arrow.Schema recs []arrow.Record @@ -55,27 +55,27 @@ func NewReader(r io.Reader, opts ...Option) (*Reader, error) { schema := schemaFromJSON(raw.Schema, &memo) dictionariesFromJSON(cfg.alloc, raw.Dictionaries, &memo) rr := &Reader{ - refs: 1, schema: schema, recs: recordsFromJSON(cfg.alloc, schema, raw.Records, &memo), memo: &memo, } + rr.refs.Add(1) return rr, nil } // Retain increases the reference count by 1. // Retain may be called simultaneously from multiple goroutines. func (r *Reader) Retain() { - atomic.AddInt64(&r.refs, 1) + r.refs.Add(1) } // Release decreases the reference count by 1. // When the reference count goes to zero, the memory is freed. // Release may be called simultaneously from multiple goroutines. func (r *Reader) Release() { - debug.Assert(atomic.LoadInt64(&r.refs) > 0, "too many releases") + debug.Assert(r.refs.Load() > 0, "too many releases") - if atomic.AddInt64(&r.refs, -1) == 0 { + if r.refs.Add(-1) == 0 { for i, rec := range r.recs { if r.recs[i] != nil { rec.Release() @@ -106,6 +106,4 @@ func (r *Reader) ReadAt(index int) (arrow.Record, error) { return rec, nil } -var ( - _ arrio.Reader = (*Reader)(nil) -) +var _ arrio.Reader = (*Reader)(nil) diff --git a/arrow/internal/flight_integration/scenario.go b/arrow/internal/flight_integration/scenario.go index 20c63003..a640d06b 100644 --- a/arrow/internal/flight_integration/scenario.go +++ b/arrow/internal/flight_integration/scenario.go @@ -2228,6 +2228,7 @@ func getQuerySchema() *arrow.Schema { IsSearchable(true). CatalogName("catalog_test"). Precision(100). + Remarks("test column"). Build().Data}}, nil) } @@ -2242,6 +2243,7 @@ func getQueryWithTransactionSchema() *arrow.Schema { SchemaName("schema_test"). IsSearchable(true). CatalogName("catalog_test"). + Remarks("test column"). 
Precision(100).Build().Data}}, nil) } diff --git a/arrow/ipc/message.go b/arrow/ipc/message.go index f989cf2a..c96869ec 100644 --- a/arrow/ipc/message.go +++ b/arrow/ipc/message.go @@ -66,7 +66,7 @@ func (m MessageType) String() string { // Message is an IPC message, including metadata and body. type Message struct { - refCount int64 + refCount atomic.Int64 msg *flatbuf.Message meta *memory.Buffer body *memory.Buffer @@ -80,12 +80,13 @@ func NewMessage(meta, body *memory.Buffer) *Message { } meta.Retain() body.Retain() - return &Message{ - refCount: 1, - msg: flatbuf.GetRootAsMessage(meta.Bytes(), 0), - meta: meta, - body: body, + m := &Message{ + msg: flatbuf.GetRootAsMessage(meta.Bytes(), 0), + meta: meta, + body: body, } + m.refCount.Add(1) + return m } func newMessageFromFB(meta *flatbuf.Message, body *memory.Buffer) *Message { @@ -93,27 +94,28 @@ func newMessageFromFB(meta *flatbuf.Message, body *memory.Buffer) *Message { panic("arrow/ipc: nil buffers") } body.Retain() - return &Message{ - refCount: 1, - msg: meta, - meta: memory.NewBufferBytes(meta.Table().Bytes), - body: body, + m := &Message{ + msg: meta, + meta: memory.NewBufferBytes(meta.Table().Bytes), + body: body, } + m.refCount.Add(1) + return m } // Retain increases the reference count by 1. // Retain may be called simultaneously from multiple goroutines. func (msg *Message) Retain() { - atomic.AddInt64(&msg.refCount, 1) + msg.refCount.Add(1) } // Release decreases the reference count by 1. // Release may be called simultaneously from multiple goroutines. // When the reference count goes to zero, the memory is freed. func (msg *Message) Release() { - debug.Assert(atomic.LoadInt64(&msg.refCount) > 0, "too many releases") + debug.Assert(msg.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&msg.refCount, -1) == 0 { + if msg.refCount.Add(-1) == 0 { msg.meta.Release() msg.body.Release() msg.msg = nil @@ -144,7 +146,7 @@ type MessageReader interface { type messageReader struct { r io.Reader - refCount int64 + refCount atomic.Int64 msg *Message mem memory.Allocator @@ -157,22 +159,24 @@ func NewMessageReader(r io.Reader, opts ...Option) MessageReader { opt(cfg) } - return &messageReader{r: r, refCount: 1, mem: cfg.alloc} + mr := &messageReader{r: r, mem: cfg.alloc} + mr.refCount.Add(1) + return mr } // Retain increases the reference count by 1. // Retain may be called simultaneously from multiple goroutines. func (r *messageReader) Retain() { - atomic.AddInt64(&r.refCount, 1) + r.refCount.Add(1) } // Release decreases the reference count by 1. // When the reference count goes to zero, the memory is freed. // Release may be called simultaneously from multiple goroutines. func (r *messageReader) Release() { - debug.Assert(atomic.LoadInt64(&r.refCount) > 0, "too many releases") + debug.Assert(r.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&r.refCount, -1) == 0 { + if r.refCount.Add(-1) == 0 { if r.msg != nil { r.msg.Release() r.msg = nil @@ -184,7 +188,7 @@ func (r *messageReader) Release() { // underlying stream. // It is valid until the next call to Message. 
func (r *messageReader) Message() (*Message, error) { - var buf = make([]byte, 4) + buf := make([]byte, 4) _, err := io.ReadFull(r.r, buf) if err != nil { return nil, fmt.Errorf("arrow/ipc: could not read continuation indicator: %w", err) diff --git a/arrow/ipc/reader.go b/arrow/ipc/reader.go index 2a4f859b..1934c719 100644 --- a/arrow/ipc/reader.go +++ b/arrow/ipc/reader.go @@ -39,7 +39,7 @@ type Reader struct { r MessageReader schema *arrow.Schema - refCount int64 + refCount atomic.Int64 rec arrow.Record err error @@ -70,13 +70,14 @@ func NewReaderFromMessageReader(r MessageReader, opts ...Option) (reader *Reader rr := &Reader{ r: r, - refCount: 1, + refCount: atomic.Int64{}, // types: make(dictTypeMap), memo: dictutils.NewMemo(), mem: cfg.alloc, ensureNativeEndian: cfg.ensureNativeEndian, expectedSchema: cfg.schema, } + rr.refCount.Add(1) if !cfg.noAutoSchema { if err := rr.readSchema(cfg.schema); err != nil { @@ -141,16 +142,16 @@ func (r *Reader) readSchema(schema *arrow.Schema) error { // Retain increases the reference count by 1. // Retain may be called simultaneously from multiple goroutines. func (r *Reader) Retain() { - atomic.AddInt64(&r.refCount, 1) + r.refCount.Add(1) } // Release decreases the reference count by 1. // When the reference count goes to zero, the memory is freed. // Release may be called simultaneously from multiple goroutines. func (r *Reader) Release() { - debug.Assert(atomic.LoadInt64(&r.refCount) > 0, "too many releases") + debug.Assert(r.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&r.refCount, -1) == 0 { + if r.refCount.Add(-1) == 0 { if r.rec != nil { r.rec.Release() r.rec = nil @@ -280,6 +281,4 @@ func (r *Reader) Read() (arrow.Record, error) { return r.rec, nil } -var ( - _ array.RecordReader = (*Reader)(nil) -) +var _ array.RecordReader = (*Reader)(nil) diff --git a/arrow/memory/buffer.go b/arrow/memory/buffer.go index 04722225..592da70c 100644 --- a/arrow/memory/buffer.go +++ b/arrow/memory/buffer.go @@ -24,7 +24,7 @@ import ( // Buffer is a wrapper type for a buffer of bytes. type Buffer struct { - refCount int64 + refCount atomic.Int64 buf []byte length int mutable bool @@ -42,22 +42,28 @@ type Buffer struct { // through the c data interface and tracking the lifetime of the // imported buffers. func NewBufferWithAllocator(data []byte, mem Allocator) *Buffer { - return &Buffer{refCount: 1, buf: data, length: len(data), mem: mem} + b := &Buffer{buf: data, length: len(data), mem: mem} + b.refCount.Add(1) + return b } // NewBufferBytes creates a fixed-size buffer from the specified data. func NewBufferBytes(data []byte) *Buffer { - return &Buffer{refCount: 0, buf: data, length: len(data)} + return &Buffer{buf: data, length: len(data)} } // NewResizableBuffer creates a mutable, resizable buffer with an Allocator for managing memory. func NewResizableBuffer(mem Allocator) *Buffer { - return &Buffer{refCount: 1, mutable: true, mem: mem} + b := &Buffer{mutable: true, mem: mem} + b.refCount.Add(1) + return b } func SliceBuffer(buf *Buffer, offset, length int) *Buffer { buf.Retain() - return &Buffer{refCount: 1, parent: buf, buf: buf.Bytes()[offset : offset+length], length: length} + b := &Buffer{parent: buf, buf: buf.Bytes()[offset : offset+length], length: length} + b.refCount.Add(1) + return b } // Parent returns either nil or a pointer to the parent buffer if this buffer @@ -67,7 +73,7 @@ func (b *Buffer) Parent() *Buffer { return b.parent } // Retain increases the reference count by 1. 
func (b *Buffer) Retain() { if b.mem != nil || b.parent != nil { - atomic.AddInt64(&b.refCount, 1) + b.refCount.Add(1) } } @@ -75,9 +81,9 @@ func (b *Buffer) Retain() { // When the reference count goes to zero, the memory is freed. func (b *Buffer) Release() { if b.mem != nil || b.parent != nil { - debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases") + debug.Assert(b.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&b.refCount, -1) == 0 { + if b.refCount.Add(-1) == 0 { if b.mem != nil { b.mem.Free(b.buf) } else { diff --git a/arrow/memory/checked_allocator.go b/arrow/memory/checked_allocator.go index 78a09a57..103a0853 100644 --- a/arrow/memory/checked_allocator.go +++ b/arrow/memory/checked_allocator.go @@ -32,7 +32,7 @@ import ( type CheckedAllocator struct { mem Allocator - sz int64 + sz atomic.Int64 allocs sync.Map } @@ -41,10 +41,10 @@ func NewCheckedAllocator(mem Allocator) *CheckedAllocator { return &CheckedAllocator{mem: mem} } -func (a *CheckedAllocator) CurrentAlloc() int { return int(atomic.LoadInt64(&a.sz)) } +func (a *CheckedAllocator) CurrentAlloc() int { return int(a.sz.Load()) } func (a *CheckedAllocator) Allocate(size int) []byte { - atomic.AddInt64(&a.sz, int64(size)) + a.sz.Add(int64(size)) out := a.mem.Allocate(size) if size == 0 { return out @@ -66,7 +66,7 @@ func (a *CheckedAllocator) Allocate(size int) []byte { } func (a *CheckedAllocator) Reallocate(size int, b []byte) []byte { - atomic.AddInt64(&a.sz, int64(size-len(b))) + a.sz.Add(int64(size - len(b))) oldptr := uintptr(unsafe.Pointer(&b[0])) out := a.mem.Reallocate(size, b) @@ -92,7 +92,7 @@ func (a *CheckedAllocator) Reallocate(size int, b []byte) []byte { } func (a *CheckedAllocator) Free(b []byte) { - atomic.AddInt64(&a.sz, int64(len(b)*-1)) + a.sz.Add(int64(len(b) * -1)) defer a.mem.Free(b) if len(b) == 0 { @@ -192,9 +192,9 @@ func (a *CheckedAllocator) AssertSize(t TestingT, sz int) { return true }) - if int(atomic.LoadInt64(&a.sz)) != sz { + if int(a.sz.Load()) != sz { t.Helper() - t.Errorf("invalid memory size exp=%d, got=%d", sz, a.sz) + t.Errorf("invalid memory size exp=%d, got=%d", sz, a.sz.Load()) } } @@ -204,18 +204,16 @@ type CheckedAllocatorScope struct { } func NewCheckedAllocatorScope(alloc *CheckedAllocator) *CheckedAllocatorScope { - sz := atomic.LoadInt64(&alloc.sz) + sz := alloc.sz.Load() return &CheckedAllocatorScope{alloc: alloc, sz: int(sz)} } func (c *CheckedAllocatorScope) CheckSize(t TestingT) { - sz := int(atomic.LoadInt64(&c.alloc.sz)) + sz := int(c.alloc.sz.Load()) if c.sz != sz { t.Helper() t.Errorf("invalid memory size exp=%d, got=%d", c.sz, sz) } } -var ( - _ Allocator = (*CheckedAllocator)(nil) -) +var _ Allocator = (*CheckedAllocator)(nil) diff --git a/arrow/scalar/append.go b/arrow/scalar/append.go index 0525bc81..737e800d 100644 --- a/arrow/scalar/append.go +++ b/arrow/scalar/append.go @@ -76,41 +76,32 @@ func appendBinary(bldr binaryBuilder, scalars []Scalar) { } } -// Append requires the passed in builder and scalar to have the same datatype -// otherwise it will return an error. Will return arrow.ErrNotImplemented if -// the type hasn't been implemented for this. -// -// NOTE only available in go1.18+ -func Append(bldr array.Builder, s Scalar) error { - return AppendSlice(bldr, []Scalar{s}) +type extbuilder interface { + array.Builder + StorageBuilder() array.Builder } -// AppendSlice requires the passed in builder and all scalars in the slice -// to have the same datatype otherwise it will return an error. 
Will return -// arrow.ErrNotImplemented if the type hasn't been implemented for this. -// -// NOTE only available in go1.18+ -func AppendSlice(bldr array.Builder, scalars []Scalar) error { +func appendToBldr(bldr array.Builder, scalars []Scalar) error { if len(scalars) == 0 { return nil } ty := bldr.Type() - for _, sc := range scalars { - if !arrow.TypeEqual(ty, sc.DataType()) { - return fmt.Errorf("%w: cannot append scalar of type %s to builder for type %s", - arrow.ErrInvalid, scalars[0].DataType(), bldr.Type()) - } - } - bldr.Reserve(len(scalars)) switch bldr := bldr.(type) { + case extbuilder: + baseScalars := make([]Scalar, len(scalars)) + for i, sc := range scalars { + baseScalars[i] = sc.(*Extension).Value + } + + return appendToBldr(bldr.StorageBuilder(), baseScalars) case *array.BooleanBuilder: - appendPrimitive[bool](bldr, scalars) + appendPrimitive(bldr, scalars) case *array.Decimal128Builder: - appendPrimitive[decimal128.Num](bldr, scalars) + appendPrimitive(bldr, scalars) case *array.Decimal256Builder: - appendPrimitive[decimal256.Num](bldr, scalars) + appendPrimitive(bldr, scalars) case *array.FixedSizeBinaryBuilder: for _, sc := range scalars { s := sc.(*FixedSizeBinary) @@ -121,45 +112,45 @@ func AppendSlice(bldr array.Builder, scalars []Scalar) error { } } case *array.Int8Builder: - appendPrimitive[int8](bldr, scalars) + appendPrimitive(bldr, scalars) case *array.Uint8Builder: - appendPrimitive[uint8](bldr, scalars) + appendPrimitive(bldr, scalars) case *array.Int16Builder: - appendPrimitive[int16](bldr, scalars) + appendPrimitive(bldr, scalars) case *array.Uint16Builder: - appendPrimitive[uint16](bldr, scalars) + appendPrimitive(bldr, scalars) case *array.Int32Builder: - appendPrimitive[int32](bldr, scalars) + appendPrimitive(bldr, scalars) case *array.Uint32Builder: - appendPrimitive[uint32](bldr, scalars) + appendPrimitive(bldr, scalars) case *array.Int64Builder: - appendPrimitive[int64](bldr, scalars) + appendPrimitive(bldr, scalars) case *array.Uint64Builder: - appendPrimitive[uint64](bldr, scalars) + appendPrimitive(bldr, scalars) case *array.Float16Builder: - appendPrimitive[float16.Num](bldr, scalars) + appendPrimitive(bldr, scalars) case *array.Float32Builder: - appendPrimitive[float32](bldr, scalars) + appendPrimitive(bldr, scalars) case *array.Float64Builder: - appendPrimitive[float64](bldr, scalars) + appendPrimitive(bldr, scalars) case *array.Date32Builder: - appendPrimitive[arrow.Date32](bldr, scalars) + appendPrimitive(bldr, scalars) case *array.Date64Builder: - appendPrimitive[arrow.Date64](bldr, scalars) + appendPrimitive(bldr, scalars) case *array.Time32Builder: - appendPrimitive[arrow.Time32](bldr, scalars) + appendPrimitive(bldr, scalars) case *array.Time64Builder: - appendPrimitive[arrow.Time64](bldr, scalars) + appendPrimitive(bldr, scalars) case *array.DayTimeIntervalBuilder: - appendPrimitive[arrow.DayTimeInterval](bldr, scalars) + appendPrimitive(bldr, scalars) case *array.MonthIntervalBuilder: - appendPrimitive[arrow.MonthInterval](bldr, scalars) + appendPrimitive(bldr, scalars) case *array.MonthDayNanoIntervalBuilder: - appendPrimitive[arrow.MonthDayNanoInterval](bldr, scalars) + appendPrimitive(bldr, scalars) case *array.DurationBuilder: - appendPrimitive[arrow.Duration](bldr, scalars) + appendPrimitive(bldr, scalars) case *array.TimestampBuilder: - appendPrimitive[arrow.Timestamp](bldr, scalars) + appendPrimitive(bldr, scalars) case array.StringLikeBuilder: appendBinary(bldr, scalars) case *array.BinaryBuilder: @@ -261,3 +252,33 @@ func 
AppendSlice(bldr array.Builder, scalars []Scalar) error { return nil } + +// Append requires the passed in builder and scalar to have the same datatype +// otherwise it will return an error. Will return arrow.ErrNotImplemented if +// the type hasn't been implemented for this. +// +// NOTE only available in go1.18+ +func Append(bldr array.Builder, s Scalar) error { + return AppendSlice(bldr, []Scalar{s}) +} + +// AppendSlice requires the passed in builder and all scalars in the slice +// to have the same datatype otherwise it will return an error. Will return +// arrow.ErrNotImplemented if the type hasn't been implemented for this. +// +// NOTE only available in go1.18+ +func AppendSlice(bldr array.Builder, scalars []Scalar) error { + if len(scalars) == 0 { + return nil + } + + ty := bldr.Type() + for _, sc := range scalars { + if !arrow.TypeEqual(ty, sc.DataType()) { + return fmt.Errorf("%w: cannot append scalar of type %s to builder for type %s", + arrow.ErrInvalid, scalars[0].DataType(), bldr.Type()) + } + } + + return appendToBldr(bldr, scalars) +} diff --git a/arrow/table.go b/arrow/table.go index 6d19d9f1..bdbf85bf 100644 --- a/arrow/table.go +++ b/arrow/table.go @@ -79,16 +79,17 @@ func NewColumnFromArr(field Field, arr Array) Column { } arr.Retain() - return Column{ + col := Column{ field: field, data: &Chunked{ - refCount: 1, - chunks: []Array{arr}, - length: arr.Len(), - nulls: arr.NullN(), - dtype: field.Type, + chunks: []Array{arr}, + length: arr.Len(), + nulls: arr.NullN(), + dtype: field.Type, }, } + col.data.refCount.Add(1) + return col } // NewColumn returns a column from a field and a chunked data array. @@ -132,7 +133,7 @@ func (col *Column) DataType() DataType { return col.field.Type } // Chunked manages a collection of primitives arrays as one logical large array. type Chunked struct { - refCount int64 // refCount must be first in the struct for 64 bit alignment and sync/atomic (https://github.com/golang/go/issues/37262) + refCount atomic.Int64 chunks []Array @@ -146,10 +147,11 @@ type Chunked struct { // NewChunked panics if the chunks do not have the same data type. func NewChunked(dtype DataType, chunks []Array) *Chunked { arr := &Chunked{ - chunks: make([]Array, 0, len(chunks)), - refCount: 1, - dtype: dtype, + chunks: make([]Array, 0, len(chunks)), + dtype: dtype, } + arr.refCount.Add(1) + for _, chunk := range chunks { if chunk == nil { continue @@ -169,16 +171,16 @@ func NewChunked(dtype DataType, chunks []Array) *Chunked { // Retain increases the reference count by 1. // Retain may be called simultaneously from multiple goroutines. func (a *Chunked) Retain() { - atomic.AddInt64(&a.refCount, 1) + a.refCount.Add(1) } // Release decreases the reference count by 1. // When the reference count goes to zero, the memory is freed. // Release may be called simultaneously from multiple goroutines. 
func (a *Chunked) Release() { - debug.Assert(atomic.LoadInt64(&a.refCount) > 0, "too many releases") + debug.Assert(a.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&a.refCount, -1) == 0 { + if a.refCount.Add(-1) == 0 { for _, arr := range a.chunks { arr.Release() } diff --git a/arrow/tensor/tensor.go b/arrow/tensor/tensor.go index b3bdf32c..70bbe572 100644 --- a/arrow/tensor/tensor.go +++ b/arrow/tensor/tensor.go @@ -65,7 +65,7 @@ type Interface interface { } type tensorBase struct { - refCount int64 + refCount atomic.Int64 dtype arrow.DataType bw int64 // bytes width data arrow.ArrayData @@ -77,16 +77,16 @@ type tensorBase struct { // Retain increases the reference count by 1. // Retain may be called simultaneously from multiple goroutines. func (tb *tensorBase) Retain() { - atomic.AddInt64(&tb.refCount, 1) + tb.refCount.Add(1) } // Release decreases the reference count by 1. // Release may be called simultaneously from multiple goroutines. // When the reference count goes to zero, the memory is freed. func (tb *tensorBase) Release() { - debug.Assert(atomic.LoadInt64(&tb.refCount) > 0, "too many releases") + debug.Assert(tb.refCount.Load() > 0, "too many releases") - if atomic.AddInt64(&tb.refCount, -1) == 0 { + if tb.refCount.Add(-1) == 0 { tb.data.Release() tb.data = nil } @@ -172,14 +172,14 @@ func New(data arrow.ArrayData, shape, strides []int64, names []string) Interface func newTensor(dtype arrow.DataType, data arrow.ArrayData, shape, strides []int64, names []string) *tensorBase { tb := tensorBase{ - refCount: 1, - dtype: dtype, - bw: int64(dtype.(arrow.FixedWidthDataType).BitWidth()) / 8, - data: data, - shape: shape, - strides: strides, - names: names, + dtype: dtype, + bw: int64(dtype.(arrow.FixedWidthDataType).BitWidth()) / 8, + data: data, + shape: shape, + strides: strides, + names: names, } + tb.refCount.Add(1) tb.data.Retain() if len(tb.shape) > 0 && len(tb.strides) == 0 { diff --git a/arrow/tmpl b/arrow/tmpl new file mode 100755 index 00000000..60df161a Binary files /dev/null and b/arrow/tmpl differ diff --git a/ci/docker/debian-12.dockerfile b/ci/docker/debian-12.dockerfile index bcf0984d..415b1d20 100644 --- a/ci/docker/debian-12.dockerfile +++ b/ci/docker/debian-12.dockerfile @@ -16,9 +16,9 @@ # under the License. ARG arch=amd64 -ARG go=1.22.6 +ARG go=1.23 FROM ${arch}/golang:${go}-bookworm # Copy the go.mod and go.sum over and pre-download all the dependencies COPY . 
/arrow-go -RUN cd /arrow-go && go mod download +RUN cd /arrow-go && go mod download github.com/apache/arrow-go/v18@latest diff --git a/go.mod b/go.mod index 313a5a88..47633cbf 100644 --- a/go.mod +++ b/go.mod @@ -23,9 +23,10 @@ toolchain go1.23.2 require ( github.com/andybalholm/brotli v1.1.1 github.com/apache/thrift v0.21.0 + github.com/cespare/xxhash/v2 v2.3.0 github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815 github.com/goccy/go-json v0.10.5 - github.com/golang/snappy v0.0.4 + github.com/golang/snappy v1.0.0 github.com/google/flatbuffers v25.2.10+incompatible github.com/google/uuid v1.6.0 github.com/hamba/avro/v2 v2.28.0 @@ -38,17 +39,17 @@ require ( github.com/pterm/pterm v0.12.80 github.com/stoewer/go-strcase v1.3.0 github.com/stretchr/testify v1.10.0 - github.com/substrait-io/substrait-go/v3 v3.9.0 + github.com/substrait-io/substrait-go/v3 v3.9.1 github.com/tidwall/sjson v1.2.5 github.com/zeebo/xxh3 v1.0.2 golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 - golang.org/x/sync v0.11.0 - golang.org/x/sys v0.31.0 - golang.org/x/tools v0.30.0 + golang.org/x/sync v0.13.0 + golang.org/x/sys v0.33.0 + golang.org/x/tools v0.32.0 golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da - gonum.org/v1/gonum v0.15.1 - google.golang.org/grpc v1.71.0 - google.golang.org/protobuf v1.36.5 + gonum.org/v1/gonum v0.16.0 + google.golang.org/grpc v1.72.0 + google.golang.org/protobuf v1.36.6 modernc.org/sqlite v1.29.6 ) @@ -86,11 +87,11 @@ require ( github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect - golang.org/x/mod v0.23.0 // indirect - golang.org/x/net v0.35.0 // indirect - golang.org/x/term v0.29.0 // indirect - golang.org/x/text v0.22.0 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f // indirect + golang.org/x/mod v0.24.0 // indirect + golang.org/x/net v0.39.0 // indirect + golang.org/x/term v0.31.0 // indirect + golang.org/x/text v0.24.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20250218202821-56aae31c358a // indirect gopkg.in/yaml.v3 v3.0.1 // indirect modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 // indirect modernc.org/libc v1.41.0 // indirect diff --git a/go.sum b/go.sum index 03470124..fef3d020 100644 --- a/go.sum +++ b/go.sum @@ -24,6 +24,8 @@ github.com/antlr4-go/antlr/v4 v4.13.1/go.mod h1:GKmUxMtwp6ZgGwZSva4eWPC5mS6vUAmO github.com/apache/thrift v0.21.0 h1:tdPmh/ptjE1IJnhbhrcl2++TauVjy242rkV/UzJChnE= github.com/apache/thrift v0.21.0/go.mod h1:W1H8aR/QRtYNvrPeFXBtobyRkd0/YVhTc6i07XIAgDw= github.com/atomicgo/cursor v0.0.1/go.mod h1:cBON2QmmrysudxNBFthvMtN32r3jxVRIvzkUiF/RuIk= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cockroachdb/apd/v3 v3.2.1 h1:U+8j7t0axsIgvQUqthuNm82HIrYXodOV2iWLWtEaIwg= github.com/cockroachdb/apd/v3 v3.2.1/go.mod h1:klXJcjp+FffLTHlhIG69tezTDvdP065naDsHzKhYSqc= github.com/containerd/console v1.0.3 h1:lIr7SlA5PxZyMV30bDW0MGbiOPXwc63yRuCP0ARubLw= @@ -56,8 +58,8 @@ github.com/goccy/go-yaml v1.11.0 h1:n7Z+zx8S9f9KgzG6KtQKf+kwqXZlLNR2F6018Dgau54= github.com/goccy/go-yaml v1.11.0/go.mod h1:H+mJrWtjPTJAHvRbV09MCK9xYwODM+wRTVFFTWckfng= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= -github.com/golang/snappy v0.0.4 
h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= -github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs= +github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/flatbuffers v25.2.10+incompatible h1:F3vclr7C3HpB1k9mxCGRMXq6FdUalZ6H/pNX4FP1v0Q= github.com/google/flatbuffers v25.2.10+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= @@ -162,8 +164,8 @@ github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOf github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/substrait-io/substrait v0.66.1-0.20250205013839-a30b3e2d7ec6 h1:XqtxwYFCjS4L0o1QD4ipGHCuFG94U0f6BeldbilGQjU= github.com/substrait-io/substrait v0.66.1-0.20250205013839-a30b3e2d7ec6/go.mod h1:MPFNw6sToJgpD5Z2rj0rQrdP/Oq8HG7Z2t3CAEHtkHw= -github.com/substrait-io/substrait-go/v3 v3.9.0 h1:sRJf0ID9q2TPxJ9eH+oAniepMqt9fYW0Hy32CScT2cI= -github.com/substrait-io/substrait-go/v3 v3.9.0/go.mod h1:VG7jCqtUm28bSngHwq86FywtU74knJ25LNX63SZ53+E= +github.com/substrait-io/substrait-go/v3 v3.9.1 h1:2yfHDHpK6KMcvLd0bJVzUJoeXO+K98yS+ciBruxD9po= +github.com/substrait-io/substrait-go/v3 v3.9.1/go.mod h1:VG7jCqtUm28bSngHwq86FywtU74knJ25LNX63SZ53+E= github.com/tidwall/gjson v1.14.2 h1:6BBkirS0rAHjumnjHF6qgy5d2YAJ1TLIaFE2lzfOLqo= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= @@ -196,25 +198,25 @@ go.opentelemetry.io/otel/trace v1.34.0 h1:+ouXS2V8Rd4hp4580a8q23bg0azF2nI8cqLYnC go.opentelemetry.io/otel/trace v1.34.0/go.mod h1:Svm7lSjQD7kG7KJ/MUHPVXSDGz2OX4h0M2jHBhmSfRE= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus= -golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M= +golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= +golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 h1:e66Fs6Z+fZTbFBAxKfP3PALWBtpfqks2bwGcexMxgtk= golang.org/x/exp v0.0.0-20240909161429-701f63a606c0/go.mod h1:2TbTHSBQa924w8M6Xs1QcRcFwyucIwBGpK1p2f1YFFY= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.23.0 h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM= -golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= +golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU= +golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.35.0 
h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8= -golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk= +golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY= +golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= -golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610= +golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -227,39 +229,39 @@ golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= -golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= +golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210615171337-6886f2dfbf5b/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= -golang.org/x/term v0.29.0 h1:L6pJp37ocefwRRtYPKSWOWzOtWSxVajvz2ldH/xi3iU= -golang.org/x/term v0.29.0/go.mod h1:6bl4lRlvVuDgSf3179VpIxBF0o10JUpXWOnI7nErv7s= +golang.org/x/term v0.31.0 h1:erwDkOK1Msy6offm1mOgvspSkslFnIGsFnxOKoufg3o= +golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= -golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= +golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= +golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY= -golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY= +golang.org/x/tools v0.32.0 h1:Q7N1vhpkQv7ybVzLFtTjvQya2ewbwNDZzUgfXGqtMWU= +golang.org/x/tools v0.32.0/go.mod h1:ZxrU41P/wAbZD8EDa6dDCa6XfpkhJ7HFMjHJXfBDu8s= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da h1:noIWHXmPHxILtqtCOPIhSt0ABwskkZKjD3bXGnZGpNY= golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90= -gonum.org/v1/gonum v0.15.1 h1:FNy7N6OUZVUaWG9pTiD+jlhdQ3lMP+/LcTpJ6+a8sQ0= -gonum.org/v1/gonum v0.15.1/go.mod h1:eZTZuRFrzu5pcyjN5wJhcIhnUdNijYxX1T2IcrOGY0o= -google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f h1:OxYkA3wjPsZyBylwymxSHa7ViiW1Sml4ToBrncvFehI= -google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f/go.mod h1:+2Yz8+CLJbIfL9z73EW45avw8Lmge3xVElCP9zEKi50= -google.golang.org/grpc v1.71.0 h1:kF77BGdPTQ4/JZWMlb9VpJ5pa25aqvVqogsxNHHdeBg= -google.golang.org/grpc v1.71.0/go.mod h1:H0GRtasmQOh9LkFoCPDu3ZrwUtD1YGE+b2vYBYd/8Ec= -google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM= -google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= +gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= +gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250218202821-56aae31c358a h1:51aaUVRocpvUOSQKM6Q7VuoaktNIaMCLuhZB6DKksq4= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250218202821-56aae31c358a/go.mod h1:uRxBH1mhmO8PGhU89cMcHaXKZqO+OfakD8QQO0oYwlQ= +google.golang.org/grpc v1.72.0 h1:S7UkcVa60b5AAQTaO6ZKamFp1zMZSU0fGDK2WZLbBnM= +google.golang.org/grpc v1.72.0/go.mod h1:wH5Aktxcg25y1I3w7H69nHfXdOG3UiadoBtjh3izSDM= +google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= +google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= diff --git a/internal/hashing/xxh3_memo_table.gen.go b/internal/hashing/xxh3_memo_table.gen.go index e99a4f8f..5f105f61 100644 --- a/internal/hashing/xxh3_memo_table.gen.go +++ b/internal/hashing/xxh3_memo_table.gen.go @@ -267,6 +267,11 @@ func (s *Int8MemoTable) WriteOutSubsetLE(start int, out []byte) { s.tbl.WriteOutSubset(start, out) } +func (s *Int8MemoTable) Exists(val int8) bool { + _, ok := s.Get(val) + return ok +} + // Get returns the index of the requested value in the hash table or KeyNotFound // along with a boolean indicating if it was found or not. func (s *Int8MemoTable) Get(val interface{}) (int, bool) { @@ -282,10 +287,13 @@ func (s *Int8MemoTable) Get(val interface{}) (int, bool) { // value into the table and return the new index. 
found indicates whether or not it already // existed in the table (true) or was inserted by this call (false). func (s *Int8MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err error) { + return s.InsertOrGet(val.(int8)) +} - h := hashInt(uint64(val.(int8)), 0) +func (s *Int8MemoTable) InsertOrGet(val int8) (idx int, found bool, err error) { + h := hashInt(uint64(val), 0) e, ok := s.tbl.Lookup(h, func(v int8) bool { - return val.(int8) == v + return val == v }) if ok { @@ -293,7 +301,7 @@ func (s *Int8MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err e found = true } else { idx = s.Size() - s.tbl.Insert(e, h, val.(int8), int32(idx)) + s.tbl.Insert(e, h, val, int32(idx)) } return } @@ -544,6 +552,11 @@ func (s *Uint8MemoTable) WriteOutSubsetLE(start int, out []byte) { s.tbl.WriteOutSubset(start, out) } +func (s *Uint8MemoTable) Exists(val uint8) bool { + _, ok := s.Get(val) + return ok +} + // Get returns the index of the requested value in the hash table or KeyNotFound // along with a boolean indicating if it was found or not. func (s *Uint8MemoTable) Get(val interface{}) (int, bool) { @@ -559,10 +572,13 @@ func (s *Uint8MemoTable) Get(val interface{}) (int, bool) { // value into the table and return the new index. found indicates whether or not it already // existed in the table (true) or was inserted by this call (false). func (s *Uint8MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err error) { + return s.InsertOrGet(val.(uint8)) +} - h := hashInt(uint64(val.(uint8)), 0) +func (s *Uint8MemoTable) InsertOrGet(val uint8) (idx int, found bool, err error) { + h := hashInt(uint64(val), 0) e, ok := s.tbl.Lookup(h, func(v uint8) bool { - return val.(uint8) == v + return val == v }) if ok { @@ -570,7 +586,7 @@ func (s *Uint8MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err found = true } else { idx = s.Size() - s.tbl.Insert(e, h, val.(uint8), int32(idx)) + s.tbl.Insert(e, h, val, int32(idx)) } return } @@ -821,6 +837,11 @@ func (s *Int16MemoTable) WriteOutSubsetLE(start int, out []byte) { s.tbl.WriteOutSubset(start, out) } +func (s *Int16MemoTable) Exists(val int16) bool { + _, ok := s.Get(val) + return ok +} + // Get returns the index of the requested value in the hash table or KeyNotFound // along with a boolean indicating if it was found or not. func (s *Int16MemoTable) Get(val interface{}) (int, bool) { @@ -836,10 +857,13 @@ func (s *Int16MemoTable) Get(val interface{}) (int, bool) { // value into the table and return the new index. found indicates whether or not it already // existed in the table (true) or was inserted by this call (false). 
func (s *Int16MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err error) { + return s.InsertOrGet(val.(int16)) +} - h := hashInt(uint64(val.(int16)), 0) +func (s *Int16MemoTable) InsertOrGet(val int16) (idx int, found bool, err error) { + h := hashInt(uint64(val), 0) e, ok := s.tbl.Lookup(h, func(v int16) bool { - return val.(int16) == v + return val == v }) if ok { @@ -847,7 +871,7 @@ func (s *Int16MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err found = true } else { idx = s.Size() - s.tbl.Insert(e, h, val.(int16), int32(idx)) + s.tbl.Insert(e, h, val, int32(idx)) } return } @@ -1098,6 +1122,11 @@ func (s *Uint16MemoTable) WriteOutSubsetLE(start int, out []byte) { s.tbl.WriteOutSubset(start, out) } +func (s *Uint16MemoTable) Exists(val uint16) bool { + _, ok := s.Get(val) + return ok +} + // Get returns the index of the requested value in the hash table or KeyNotFound // along with a boolean indicating if it was found or not. func (s *Uint16MemoTable) Get(val interface{}) (int, bool) { @@ -1113,10 +1142,13 @@ func (s *Uint16MemoTable) Get(val interface{}) (int, bool) { // value into the table and return the new index. found indicates whether or not it already // existed in the table (true) or was inserted by this call (false). func (s *Uint16MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err error) { + return s.InsertOrGet(val.(uint16)) +} - h := hashInt(uint64(val.(uint16)), 0) +func (s *Uint16MemoTable) InsertOrGet(val uint16) (idx int, found bool, err error) { + h := hashInt(uint64(val), 0) e, ok := s.tbl.Lookup(h, func(v uint16) bool { - return val.(uint16) == v + return val == v }) if ok { @@ -1124,7 +1156,7 @@ func (s *Uint16MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err found = true } else { idx = s.Size() - s.tbl.Insert(e, h, val.(uint16), int32(idx)) + s.tbl.Insert(e, h, val, int32(idx)) } return } @@ -1375,6 +1407,11 @@ func (s *Int32MemoTable) WriteOutSubsetLE(start int, out []byte) { s.tbl.WriteOutSubset(start, out) } +func (s *Int32MemoTable) Exists(val int32) bool { + _, ok := s.Get(val) + return ok +} + // Get returns the index of the requested value in the hash table or KeyNotFound // along with a boolean indicating if it was found or not. func (s *Int32MemoTable) Get(val interface{}) (int, bool) { @@ -1390,10 +1427,13 @@ func (s *Int32MemoTable) Get(val interface{}) (int, bool) { // value into the table and return the new index. found indicates whether or not it already // existed in the table (true) or was inserted by this call (false). func (s *Int32MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err error) { + return s.InsertOrGet(val.(int32)) +} - h := hashInt(uint64(val.(int32)), 0) +func (s *Int32MemoTable) InsertOrGet(val int32) (idx int, found bool, err error) { + h := hashInt(uint64(val), 0) e, ok := s.tbl.Lookup(h, func(v int32) bool { - return val.(int32) == v + return val == v }) if ok { @@ -1401,7 +1441,7 @@ func (s *Int32MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err found = true } else { idx = s.Size() - s.tbl.Insert(e, h, val.(int32), int32(idx)) + s.tbl.Insert(e, h, val, int32(idx)) } return } @@ -1652,6 +1692,11 @@ func (s *Int64MemoTable) WriteOutSubsetLE(start int, out []byte) { s.tbl.WriteOutSubset(start, out) } +func (s *Int64MemoTable) Exists(val int64) bool { + _, ok := s.Get(val) + return ok +} + // Get returns the index of the requested value in the hash table or KeyNotFound // along with a boolean indicating if it was found or not. 
func (s *Int64MemoTable) Get(val interface{}) (int, bool) { @@ -1667,10 +1712,13 @@ func (s *Int64MemoTable) Get(val interface{}) (int, bool) { // value into the table and return the new index. found indicates whether or not it already // existed in the table (true) or was inserted by this call (false). func (s *Int64MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err error) { + return s.InsertOrGet(val.(int64)) +} - h := hashInt(uint64(val.(int64)), 0) +func (s *Int64MemoTable) InsertOrGet(val int64) (idx int, found bool, err error) { + h := hashInt(uint64(val), 0) e, ok := s.tbl.Lookup(h, func(v int64) bool { - return val.(int64) == v + return val == v }) if ok { @@ -1678,7 +1726,7 @@ func (s *Int64MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err found = true } else { idx = s.Size() - s.tbl.Insert(e, h, val.(int64), int32(idx)) + s.tbl.Insert(e, h, val, int32(idx)) } return } @@ -1929,6 +1977,11 @@ func (s *Uint32MemoTable) WriteOutSubsetLE(start int, out []byte) { s.tbl.WriteOutSubset(start, out) } +func (s *Uint32MemoTable) Exists(val uint32) bool { + _, ok := s.Get(val) + return ok +} + // Get returns the index of the requested value in the hash table or KeyNotFound // along with a boolean indicating if it was found or not. func (s *Uint32MemoTable) Get(val interface{}) (int, bool) { @@ -1944,10 +1997,13 @@ func (s *Uint32MemoTable) Get(val interface{}) (int, bool) { // value into the table and return the new index. found indicates whether or not it already // existed in the table (true) or was inserted by this call (false). func (s *Uint32MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err error) { + return s.InsertOrGet(val.(uint32)) +} - h := hashInt(uint64(val.(uint32)), 0) +func (s *Uint32MemoTable) InsertOrGet(val uint32) (idx int, found bool, err error) { + h := hashInt(uint64(val), 0) e, ok := s.tbl.Lookup(h, func(v uint32) bool { - return val.(uint32) == v + return val == v }) if ok { @@ -1955,7 +2011,7 @@ func (s *Uint32MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err found = true } else { idx = s.Size() - s.tbl.Insert(e, h, val.(uint32), int32(idx)) + s.tbl.Insert(e, h, val, int32(idx)) } return } @@ -2206,6 +2262,11 @@ func (s *Uint64MemoTable) WriteOutSubsetLE(start int, out []byte) { s.tbl.WriteOutSubset(start, out) } +func (s *Uint64MemoTable) Exists(val uint64) bool { + _, ok := s.Get(val) + return ok +} + // Get returns the index of the requested value in the hash table or KeyNotFound // along with a boolean indicating if it was found or not. func (s *Uint64MemoTable) Get(val interface{}) (int, bool) { @@ -2221,10 +2282,13 @@ func (s *Uint64MemoTable) Get(val interface{}) (int, bool) { // value into the table and return the new index. found indicates whether or not it already // existed in the table (true) or was inserted by this call (false). 
func (s *Uint64MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err error) { + return s.InsertOrGet(val.(uint64)) +} - h := hashInt(uint64(val.(uint64)), 0) +func (s *Uint64MemoTable) InsertOrGet(val uint64) (idx int, found bool, err error) { + h := hashInt(uint64(val), 0) e, ok := s.tbl.Lookup(h, func(v uint64) bool { - return val.(uint64) == v + return val == v }) if ok { @@ -2232,7 +2296,7 @@ func (s *Uint64MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err found = true } else { idx = s.Size() - s.tbl.Insert(e, h, val.(uint64), int32(idx)) + s.tbl.Insert(e, h, val, int32(idx)) } return } @@ -2483,6 +2547,11 @@ func (s *Float32MemoTable) WriteOutSubsetLE(start int, out []byte) { s.tbl.WriteOutSubset(start, out) } +func (s *Float32MemoTable) Exists(val float32) bool { + _, ok := s.Get(val) + return ok +} + // Get returns the index of the requested value in the hash table or KeyNotFound // along with a boolean indicating if it was found or not. func (s *Float32MemoTable) Get(val interface{}) (int, bool) { @@ -2508,19 +2577,23 @@ func (s *Float32MemoTable) Get(val interface{}) (int, bool) { // value into the table and return the new index. found indicates whether or not it already // existed in the table (true) or was inserted by this call (false). func (s *Float32MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err error) { + return s.InsertOrGet(val.(float32)) +} + +func (s *Float32MemoTable) InsertOrGet(val float32) (idx int, found bool, err error) { var cmp func(float32) bool - if math.IsNaN(float64(val.(float32))) { + if math.IsNaN(float64(val)) { cmp = isNan32Cmp // use consistent internal bit pattern for NaN regardless of the pattern // that is passed to us. NaN is NaN is NaN val = float32(math.NaN()) } else { - cmp = func(v float32) bool { return val.(float32) == v } + cmp = func(v float32) bool { return val == v } } - h := hashFloat32(val.(float32), 0) + h := hashFloat32(val, 0) e, ok := s.tbl.Lookup(h, cmp) if ok { @@ -2528,7 +2601,7 @@ func (s *Float32MemoTable) GetOrInsert(val interface{}) (idx int, found bool, er found = true } else { idx = s.Size() - s.tbl.Insert(e, h, val.(float32), int32(idx)) + s.tbl.Insert(e, h, val, int32(idx)) } return } @@ -2779,6 +2852,11 @@ func (s *Float64MemoTable) WriteOutSubsetLE(start int, out []byte) { s.tbl.WriteOutSubset(start, out) } +func (s *Float64MemoTable) Exists(val float64) bool { + _, ok := s.Get(val) + return ok +} + // Get returns the index of the requested value in the hash table or KeyNotFound // along with a boolean indicating if it was found or not. func (s *Float64MemoTable) Get(val interface{}) (int, bool) { @@ -2803,18 +2881,22 @@ func (s *Float64MemoTable) Get(val interface{}) (int, bool) { // value into the table and return the new index. found indicates whether or not it already // existed in the table (true) or was inserted by this call (false). func (s *Float64MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err error) { + return s.InsertOrGet(val.(float64)) +} + +func (s *Float64MemoTable) InsertOrGet(val float64) (idx int, found bool, err error) { var cmp func(float64) bool - if math.IsNaN(val.(float64)) { + if math.IsNaN(val) { cmp = math.IsNaN // use consistent internal bit pattern for NaN regardless of the pattern // that is passed to us. 
NaN is NaN is NaN val = math.NaN() } else { - cmp = func(v float64) bool { return val.(float64) == v } + cmp = func(v float64) bool { return val == v } } - h := hashFloat64(val.(float64), 0) + h := hashFloat64(val, 0) e, ok := s.tbl.Lookup(h, cmp) if ok { @@ -2822,7 +2904,7 @@ func (s *Float64MemoTable) GetOrInsert(val interface{}) (idx int, found bool, er found = true } else { idx = s.Size() - s.tbl.Insert(e, h, val.(float64), int32(idx)) + s.tbl.Insert(e, h, val, int32(idx)) } return } diff --git a/internal/hashing/xxh3_memo_table.gen.go.tmpl b/internal/hashing/xxh3_memo_table.gen.go.tmpl index 9ba35c72..14a8f212 100644 --- a/internal/hashing/xxh3_memo_table.gen.go.tmpl +++ b/internal/hashing/xxh3_memo_table.gen.go.tmpl @@ -267,6 +267,11 @@ func (s *{{.Name}}MemoTable) WriteOutSubsetLE(start int, out []byte) { s.tbl.WriteOutSubset(start, out) } +func (s *{{.Name}}MemoTable) Exists(val {{.name}}) bool { + _, ok := s.Get(val) + return ok +} + // Get returns the index of the requested value in the hash table or KeyNotFound // along with a boolean indicating if it was found or not. func (s *{{.Name}}MemoTable) Get(val interface{}) (int, bool) { @@ -304,31 +309,35 @@ func (s *{{.Name}}MemoTable) Get(val interface{}) (int, bool) { // value into the table and return the new index. found indicates whether or not it already // existed in the table (true) or was inserted by this call (false). func (s *{{.Name}}MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err error) { - {{if and (ne .Name "Float32") (ne .Name "Float64") }} - h := hashInt(uint64(val.({{.name}})), 0) + return s.InsertOrGet(val.({{.name}})) +} + +func (s *{{.Name}}MemoTable) InsertOrGet(val {{.name}}) (idx int, found bool, err error) { + {{if and (ne .Name "Float32") (ne .Name "Float64") -}} + h := hashInt(uint64(val), 0) e, ok := s.tbl.Lookup(h, func(v {{.name}}) bool { - return val.({{.name}}) == v + return val == v }) {{ else }} var cmp func({{.name}}) bool {{if eq .Name "Float32"}} - if math.IsNaN(float64(val.(float32))) { + if math.IsNaN(float64(val)) { cmp = isNan32Cmp // use consistent internal bit pattern for NaN regardless of the pattern // that is passed to us. NaN is NaN is NaN val = float32(math.NaN()) {{ else -}} - if math.IsNaN(val.(float64)) { + if math.IsNaN(val) { cmp = math.IsNaN // use consistent internal bit pattern for NaN regardless of the pattern // that is passed to us. 
NaN is NaN is NaN val = math.NaN() {{end -}} } else { - cmp = func(v {{.name}}) bool { return val.({{.name}}) == v } + cmp = func(v {{.name}}) bool { return val == v } } - h := hash{{.Name}}(val.({{.name}}), 0) + h := hash{{.Name}}(val, 0) e, ok := s.tbl.Lookup(h, cmp) {{ end }} if ok { @@ -336,7 +345,7 @@ func (s *{{.Name}}MemoTable) GetOrInsert(val interface{}) (idx int, found bool, found = true } else { idx = s.Size() - s.tbl.Insert(e, h, val.({{.name}}), int32(idx)) + s.tbl.Insert(e, h, val, int32(idx)) } return } diff --git a/internal/hashing/xxh3_memo_table.go b/internal/hashing/xxh3_memo_table.go index fbb8b335..f10a9b21 100644 --- a/internal/hashing/xxh3_memo_table.go +++ b/internal/hashing/xxh3_memo_table.go @@ -74,6 +74,18 @@ type MemoTable interface { WriteOutSubset(offset int, out []byte) } +type MemoTypes interface { + int8 | int16 | int32 | int64 | + uint8 | uint16 | uint32 | uint64 | + float32 | float64 | []byte +} + +type TypedMemoTable[T MemoTypes] interface { + MemoTable + Exists(T) bool + InsertOrGet(val T) (idx int, found bool, err error) +} + type NumericMemoTable interface { MemoTable WriteOutLE(out []byte) @@ -202,25 +214,17 @@ func (BinaryMemoTable) getHash(val interface{}) uint64 { } } -// helper function to append the given value to the builder regardless -// of the underlying binary type. -func (b *BinaryMemoTable) appendVal(val interface{}) { - switch v := val.(type) { - case string: - b.builder.AppendString(v) - case []byte: - b.builder.Append(v) - case ByteSlice: - b.builder.Append(v.Bytes()) - } -} - func (b *BinaryMemoTable) lookup(h uint64, val []byte) (*entryInt32, bool) { return b.tbl.Lookup(h, func(i int32) bool { return bytes.Equal(val, b.builder.Value(int(i))) }) } +func (b *BinaryMemoTable) Exists(val []byte) bool { + _, ok := b.lookup(b.getHash(val), val) + return ok +} + // Get returns the index of the specified value in the table or KeyNotFound, // and a boolean indicating whether it was found in the table. func (b *BinaryMemoTable) Get(val interface{}) (int, bool) { @@ -246,17 +250,21 @@ func (b *BinaryMemoTable) GetOrInsertBytes(val []byte) (idx int, found bool, err return } +func (b *BinaryMemoTable) GetOrInsert(val interface{}) (idx int, found bool, err error) { + return b.InsertOrGet(b.valAsByteSlice(val)) +} + // GetOrInsert returns the index of the given value in the table, if not found // it is inserted into the table. The return value 'found' indicates whether the value // was found in the table (true) or inserted (false) along with any possible error. 
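Both the float specializations and the template above normalize every incoming NaN to math.NaN()'s bit pattern and match existing entries with an IsNaN comparison, so all NaN payloads collapse onto a single memo index. A small standalone illustration of why that normalization is needed, using only the standard library:

package main

import (
	"fmt"
	"math"
)

func main() {
	// Two values that are both NaN but carry different bit patterns.
	canonical := math.NaN()
	payload := math.Float64frombits(math.Float64bits(canonical) | 0x4)

	fmt.Println(math.IsNaN(canonical), math.IsNaN(payload)) // true true
	fmt.Println(canonical == payload)                       // false: NaN never compares equal, even to itself
	fmt.Printf("%#x\n%#x\n", math.Float64bits(canonical), math.Float64bits(payload))

	// Hashing the raw bits would scatter these NaNs across buckets and equality
	// checks would always fail, so the memo table stores math.NaN() and compares
	// with an IsNaN predicate instead.
}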
-func (b *BinaryMemoTable) GetOrInsert(val interface{}) (idx int, found bool, err error) { +func (b *BinaryMemoTable) InsertOrGet(val []byte) (idx int, found bool, err error) { h := b.getHash(val) - p, found := b.lookup(h, b.valAsByteSlice(val)) + p, found := b.lookup(h, val) if found { idx = int(p.payload.val) } else { idx = b.Size() - b.appendVal(val) + b.builder.Append(val) b.tbl.Insert(p, h, int32(idx), -1) } return diff --git a/parquet/compress/compress.go b/parquet/compress/compress.go index 72768e01..dfa07458 100644 --- a/parquet/compress/compress.go +++ b/parquet/compress/compress.go @@ -33,6 +33,14 @@ func (c Compression) String() string { return parquet.CompressionCodec(c).String() } +func (c Compression) MarshalText() ([]byte, error) { + return parquet.CompressionCodec(c).MarshalText() +} + +func (c *Compression) UnmarshalText(text []byte) error { + return (*parquet.CompressionCodec)(c).UnmarshalText(text) +} + // DefaultCompressionLevel will use flate.DefaultCompression since many of the compression libraries // use that to denote "use the default". const DefaultCompressionLevel = flate.DefaultCompression diff --git a/parquet/compress/compress_test.go b/parquet/compress/compress_test.go index e6c135d9..00f13760 100644 --- a/parquet/compress/compress_test.go +++ b/parquet/compress/compress_test.go @@ -138,3 +138,44 @@ func TestCompressReaderWriter(t *testing.T) { }) } } + +var marshalTests = []struct { + text string + codec compress.Compression +}{ + {"UNCOMPRESSED", compress.Codecs.Uncompressed}, + {"SNAPPY", compress.Codecs.Snappy}, + {"GZIP", compress.Codecs.Gzip}, + {"LZO", compress.Codecs.Lzo}, + {"BROTLI", compress.Codecs.Brotli}, + {"LZ4", compress.Codecs.Lz4}, + {"ZSTD", compress.Codecs.Zstd}, + {"LZ4_RAW", compress.Codecs.Lz4Raw}, +} + +func TestMarshalText(t *testing.T) { + for _, tt := range marshalTests { + t.Run(tt.text, func(t *testing.T) { + data, err := tt.codec.MarshalText() + assert.NoError(t, err) + assert.Equal(t, tt.text, string(data)) + }) + } +} + +func TestUnmarshalText(t *testing.T) { + for _, tt := range marshalTests { + t.Run(tt.text, func(t *testing.T) { + var compression compress.Compression + err := compression.UnmarshalText([]byte(tt.text)) + assert.NoError(t, err) + assert.Equal(t, tt.codec, compression) + }) + } +} + +func TestUnmarshalTextError(t *testing.T) { + var compression compress.Compression + err := compression.UnmarshalText([]byte("NO SUCH CODEC")) + assert.EqualError(t, err, "not a valid CompressionCodec string") +} diff --git a/parquet/file/column_reader.go b/parquet/file/column_reader.go index 5faf8bc0..03ca5a8f 100644 --- a/parquet/file/column_reader.go +++ b/parquet/file/column_reader.go @@ -223,6 +223,7 @@ func (c *columnChunkReader) pager() PageReader { return c.rdr } func (c *columnChunkReader) setPageReader(rdr PageReader) { c.rdr, c.err = rdr, nil c.decoders = make(map[format.Encoding]encoding.TypedDecoder) + c.newDictionary = false c.numBuffered, c.numDecoded = 0, 0 } diff --git a/parquet/file/column_writer.go b/parquet/file/column_writer.go index f608cd0f..8d35aaa3 100644 --- a/parquet/file/column_writer.go +++ b/parquet/file/column_writer.go @@ -20,12 +20,14 @@ import ( "bytes" "encoding/binary" "io" + "strconv" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/array" "github.com/apache/arrow-go/v18/arrow/bitutil" "github.com/apache/arrow-go/v18/arrow/memory" "github.com/apache/arrow-go/v18/parquet" + "github.com/apache/arrow-go/v18/parquet/internal/debug" 
"github.com/apache/arrow-go/v18/parquet/internal/encoding" "github.com/apache/arrow-go/v18/parquet/metadata" "github.com/apache/arrow-go/v18/parquet/schema" @@ -76,6 +78,8 @@ type ColumnChunkWriter interface { LevelInfo() LevelInfo SetBitsBuffer(*memory.Buffer) HasBitsBuffer() bool + + GetBloomFilter() metadata.BloomFilterBuilder } func computeLevelInfo(descr *schema.Column) (info LevelInfo) { @@ -111,6 +115,7 @@ type columnWriter struct { pageStatistics metadata.TypedStatistics chunkStatistics metadata.TypedStatistics + bloomFilter metadata.BloomFilterBuilder // total number of values stored in the current data page. this is the maximum // of the number of encoded def levels or encoded values. for @@ -172,6 +177,7 @@ func newColumnWriterBase(metaData *metadata.ColumnChunkMetaDataBuilder, pager Pa ret.reset() + ret.initBloomFilter() return ret } @@ -416,9 +422,12 @@ func (w *columnWriter) FlushBufferedDataPages() (err error) { } } - for _, p := range w.pages { + for i, p := range w.pages { defer p.Release() if err = w.WriteDataPage(p); err != nil { + // To keep pages in consistent state, + // remove the pages that will be released using above defer call. + w.pages = w.pages[i+1:] return err } } @@ -428,10 +437,14 @@ func (w *columnWriter) FlushBufferedDataPages() (err error) { func (w *columnWriter) writeLevels(numValues int64, defLevels, repLevels []int16) int64 { toWrite := int64(0) + maxDefLevel := w.descr.MaxDefinitionLevel() + // if the field is required and non-repeated, no definition levels - if defLevels != nil && w.descr.MaxDefinitionLevel() > 0 { + if defLevels != nil && maxDefLevel > 0 { for _, v := range defLevels[:numValues] { - if v == w.descr.MaxDefinitionLevel() { + debug.Assert(v <= maxDefLevel, "columnwriter: invalid definition level "+ + strconv.Itoa(int(v))+" for column "+w.descr.Path()) + if v == maxDefLevel { toWrite++ } } @@ -696,3 +709,29 @@ func (w *columnWriter) maybeReplaceValidity(values arrow.Array, newNullCount int defer data.Release() return array.MakeFromData(data) } + +func (w *columnWriter) initBloomFilter() { + path := w.descr.Path() + if !w.props.BloomFilterEnabledFor(path) { + return + } + + maxFilterBytes := w.props.MaxBloomFilterBytes() + ndv := w.props.BloomFilterNDVFor(path) + fpp := w.props.BloomFilterFPPFor(path) + // if user specified the column NDV, we can construct the bloom filter for it + if ndv > 0 { + w.bloomFilter = metadata.NewBloomFilterFromNDVAndFPP(uint32(ndv), fpp, maxFilterBytes, w.mem) + } else if w.props.AdaptiveBloomFilterEnabledFor(path) { + numCandidates := w.props.BloomFilterCandidatesFor(path) + // construct adaptive bloom filter writer + w.bloomFilter = metadata.NewAdaptiveBlockSplitBloomFilter(uint32(maxFilterBytes), numCandidates, fpp, w.descr, w.mem) + } else { + // construct a bloom filter using the max size + w.bloomFilter = metadata.NewBloomFilter(uint32(maxFilterBytes), uint32(maxFilterBytes), w.mem) + } +} + +func (w *columnWriter) GetBloomFilter() metadata.BloomFilterBuilder { + return w.bloomFilter +} diff --git a/parquet/file/column_writer_test.go b/parquet/file/column_writer_test.go index 0f245c87..90b239e4 100644 --- a/parquet/file/column_writer_test.go +++ b/parquet/file/column_writer_test.go @@ -18,6 +18,7 @@ package file_test import ( "bytes" + "errors" "math" "reflect" "runtime" @@ -792,3 +793,42 @@ func TestDictionaryReslice(t *testing.T) { }) } } + +// TestWriteDataFailure asserts that if WriteDataPage fails, the internal state of the +// ColumnChunkWriter is still valid and accessing function does not 
panic. +func TestWriteDataFailure(t *testing.T) { + sc := schema.NewSchema(schema.MustGroup(schema.NewGroupNode("schema", parquet.Repetitions.Required, schema.FieldList{ + schema.Must(schema.ListOf( + schema.Must(schema.NewPrimitiveNode("column", parquet.Repetitions.Optional, parquet.Types.Int32, -1, -1)), + parquet.Repetitions.Optional, -1)), + }, -1))) + descr := sc.Column(0) + props := parquet.NewWriterProperties( + parquet.WithStats(true), + parquet.WithVersion(parquet.V1_0), + parquet.WithDataPageVersion(parquet.DataPageV1), + parquet.WithDictionaryDefault(true)) // true to enable buffering pages. + + metadata := metadata.NewColumnChunkMetaDataBuilder(props, descr) + pager := new(mockpagewriter) + defer pager.AssertExpectations(t) + pager.On("HasCompressor").Return(false) + wr := file.NewColumnChunkWriter(metadata, pager, props).(*file.Int32ColumnChunkWriter) + + // Write some valid data. + wr.WriteBatch([]int32{0, 1, 2, 3}, + []int16{3, 3, 0, 3, 2, 3}, + []int16{0, 1, 0, 0, 1, 1}) + + // Simulate WriteDataPage failure. + failureErr := errors.New("mock error from WriteDataPage") + pager.On("WriteDataPage", mock.MatchedBy(func(page file.DataPage) bool { + _, ok := page.(*file.DataPageV1) + return ok + })).Return(0, failureErr) + + // Expect error from FlushBufferedDataPages but it should leave internal fields in valid state. + err := wr.FlushBufferedDataPages() + assert.Equal(t, err, failureErr) + assert.Equal(t, int64(0), wr.TotalBytesWritten()) +} diff --git a/parquet/file/column_writer_types.gen.go b/parquet/file/column_writer_types.gen.go index cb022337..65bf29a7 100644 --- a/parquet/file/column_writer_types.gen.go +++ b/parquet/file/column_writer_types.gen.go @@ -184,6 +184,10 @@ func (w *Int32ColumnChunkWriter) writeValues(values []int32, numNulls int64) { if w.pageStatistics != nil { w.pageStatistics.(*metadata.Int32Statistics).Update(values, numNulls) } + if w.bloomFilter != nil { + // TODO: optimize for Dictionary Encoding case + w.bloomFilter.InsertBulk(metadata.GetHashes(w.bloomFilter.Hasher(), values)) + } } func (w *Int32ColumnChunkWriter) writeValuesSpaced(spacedValues []int32, numRead, numValues int64, validBits []byte, validBitsOffset int64) { @@ -196,6 +200,10 @@ func (w *Int32ColumnChunkWriter) writeValuesSpaced(spacedValues []int32, numRead nulls := numValues - numRead w.pageStatistics.(*metadata.Int32Statistics).UpdateSpaced(spacedValues, validBits, validBitsOffset, nulls) } + if w.bloomFilter != nil { + // TODO: optimize for Dictionary Encoding case + w.bloomFilter.InsertBulk(metadata.GetSpacedHashes(w.bloomFilter.Hasher(), numRead, spacedValues, validBits, validBitsOffset)) + } } func (w *Int32ColumnChunkWriter) checkDictionarySizeLimit() { @@ -374,6 +382,10 @@ func (w *Int64ColumnChunkWriter) writeValues(values []int64, numNulls int64) { if w.pageStatistics != nil { w.pageStatistics.(*metadata.Int64Statistics).Update(values, numNulls) } + if w.bloomFilter != nil { + // TODO: optimize for Dictionary Encoding case + w.bloomFilter.InsertBulk(metadata.GetHashes(w.bloomFilter.Hasher(), values)) + } } func (w *Int64ColumnChunkWriter) writeValuesSpaced(spacedValues []int64, numRead, numValues int64, validBits []byte, validBitsOffset int64) { @@ -386,6 +398,10 @@ func (w *Int64ColumnChunkWriter) writeValuesSpaced(spacedValues []int64, numRead nulls := numValues - numRead w.pageStatistics.(*metadata.Int64Statistics).UpdateSpaced(spacedValues, validBits, validBitsOffset, nulls) } + if w.bloomFilter != nil { + // TODO: optimize for Dictionary Encoding case + 
w.bloomFilter.InsertBulk(metadata.GetSpacedHashes(w.bloomFilter.Hasher(), numRead, spacedValues, validBits, validBitsOffset)) + } } func (w *Int64ColumnChunkWriter) checkDictionarySizeLimit() { @@ -564,6 +580,10 @@ func (w *Int96ColumnChunkWriter) writeValues(values []parquet.Int96, numNulls in if w.pageStatistics != nil { w.pageStatistics.(*metadata.Int96Statistics).Update(values, numNulls) } + if w.bloomFilter != nil { + // TODO: optimize for Dictionary Encoding case + w.bloomFilter.InsertBulk(metadata.GetHashes(w.bloomFilter.Hasher(), values)) + } } func (w *Int96ColumnChunkWriter) writeValuesSpaced(spacedValues []parquet.Int96, numRead, numValues int64, validBits []byte, validBitsOffset int64) { @@ -576,6 +596,10 @@ func (w *Int96ColumnChunkWriter) writeValuesSpaced(spacedValues []parquet.Int96, nulls := numValues - numRead w.pageStatistics.(*metadata.Int96Statistics).UpdateSpaced(spacedValues, validBits, validBitsOffset, nulls) } + if w.bloomFilter != nil { + // TODO: optimize for Dictionary Encoding case + w.bloomFilter.InsertBulk(metadata.GetSpacedHashes(w.bloomFilter.Hasher(), numRead, spacedValues, validBits, validBitsOffset)) + } } func (w *Int96ColumnChunkWriter) checkDictionarySizeLimit() { @@ -754,6 +778,10 @@ func (w *Float32ColumnChunkWriter) writeValues(values []float32, numNulls int64) if w.pageStatistics != nil { w.pageStatistics.(*metadata.Float32Statistics).Update(values, numNulls) } + if w.bloomFilter != nil { + // TODO: optimize for Dictionary Encoding case + w.bloomFilter.InsertBulk(metadata.GetHashes(w.bloomFilter.Hasher(), values)) + } } func (w *Float32ColumnChunkWriter) writeValuesSpaced(spacedValues []float32, numRead, numValues int64, validBits []byte, validBitsOffset int64) { @@ -766,6 +794,10 @@ func (w *Float32ColumnChunkWriter) writeValuesSpaced(spacedValues []float32, num nulls := numValues - numRead w.pageStatistics.(*metadata.Float32Statistics).UpdateSpaced(spacedValues, validBits, validBitsOffset, nulls) } + if w.bloomFilter != nil { + // TODO: optimize for Dictionary Encoding case + w.bloomFilter.InsertBulk(metadata.GetSpacedHashes(w.bloomFilter.Hasher(), numRead, spacedValues, validBits, validBitsOffset)) + } } func (w *Float32ColumnChunkWriter) checkDictionarySizeLimit() { @@ -944,6 +976,10 @@ func (w *Float64ColumnChunkWriter) writeValues(values []float64, numNulls int64) if w.pageStatistics != nil { w.pageStatistics.(*metadata.Float64Statistics).Update(values, numNulls) } + if w.bloomFilter != nil { + // TODO: optimize for Dictionary Encoding case + w.bloomFilter.InsertBulk(metadata.GetHashes(w.bloomFilter.Hasher(), values)) + } } func (w *Float64ColumnChunkWriter) writeValuesSpaced(spacedValues []float64, numRead, numValues int64, validBits []byte, validBitsOffset int64) { @@ -956,6 +992,10 @@ func (w *Float64ColumnChunkWriter) writeValuesSpaced(spacedValues []float64, num nulls := numValues - numRead w.pageStatistics.(*metadata.Float64Statistics).UpdateSpaced(spacedValues, validBits, validBitsOffset, nulls) } + if w.bloomFilter != nil { + // TODO: optimize for Dictionary Encoding case + w.bloomFilter.InsertBulk(metadata.GetSpacedHashes(w.bloomFilter.Hasher(), numRead, spacedValues, validBits, validBitsOffset)) + } } func (w *Float64ColumnChunkWriter) checkDictionarySizeLimit() { @@ -1137,6 +1177,10 @@ func (w *BooleanColumnChunkWriter) writeValues(values []bool, numNulls int64) { if w.pageStatistics != nil { w.pageStatistics.(*metadata.BooleanStatistics).Update(values, numNulls) } + if w.bloomFilter != nil { + // TODO: optimize for Dictionary 
Encoding case + w.bloomFilter.InsertBulk(metadata.GetHashes(w.bloomFilter.Hasher(), values)) + } } func (w *BooleanColumnChunkWriter) writeValuesSpaced(spacedValues []bool, numRead, numValues int64, validBits []byte, validBitsOffset int64) { @@ -1149,6 +1193,10 @@ func (w *BooleanColumnChunkWriter) writeValuesSpaced(spacedValues []bool, numRea nulls := numValues - numRead w.pageStatistics.(*metadata.BooleanStatistics).UpdateSpaced(spacedValues, validBits, validBitsOffset, nulls) } + if w.bloomFilter != nil { + // TODO: optimize for Dictionary Encoding case + w.bloomFilter.InsertBulk(metadata.GetSpacedHashes(w.bloomFilter.Hasher(), numRead, spacedValues, validBits, validBitsOffset)) + } } func (w *BooleanColumnChunkWriter) checkDictionarySizeLimit() { @@ -1327,6 +1375,10 @@ func (w *ByteArrayColumnChunkWriter) writeValues(values []parquet.ByteArray, num if w.pageStatistics != nil { w.pageStatistics.(*metadata.ByteArrayStatistics).Update(values, numNulls) } + if w.bloomFilter != nil { + // TODO: optimize for Dictionary Encoding case + w.bloomFilter.InsertBulk(metadata.GetHashes(w.bloomFilter.Hasher(), values)) + } } func (w *ByteArrayColumnChunkWriter) writeValuesSpaced(spacedValues []parquet.ByteArray, numRead, numValues int64, validBits []byte, validBitsOffset int64) { @@ -1339,6 +1391,10 @@ func (w *ByteArrayColumnChunkWriter) writeValuesSpaced(spacedValues []parquet.By nulls := numValues - numRead w.pageStatistics.(*metadata.ByteArrayStatistics).UpdateSpaced(spacedValues, validBits, validBitsOffset, nulls) } + if w.bloomFilter != nil { + // TODO: optimize for Dictionary Encoding case + w.bloomFilter.InsertBulk(metadata.GetSpacedHashes(w.bloomFilter.Hasher(), numRead, spacedValues, validBits, validBitsOffset)) + } } func (w *ByteArrayColumnChunkWriter) checkDictionarySizeLimit() { @@ -1521,6 +1577,10 @@ func (w *FixedLenByteArrayColumnChunkWriter) writeValues(values []parquet.FixedL w.pageStatistics.(*metadata.FixedLenByteArrayStatistics).Update(values, numNulls) } } + if w.bloomFilter != nil { + // TODO: optimize for Dictionary Encoding case + w.bloomFilter.InsertBulk(metadata.GetHashes(w.bloomFilter.Hasher(), values)) + } } func (w *FixedLenByteArrayColumnChunkWriter) writeValuesSpaced(spacedValues []parquet.FixedLenByteArray, numRead, numValues int64, validBits []byte, validBitsOffset int64) { @@ -1537,6 +1597,10 @@ func (w *FixedLenByteArrayColumnChunkWriter) writeValuesSpaced(spacedValues []pa w.pageStatistics.(*metadata.FixedLenByteArrayStatistics).UpdateSpaced(spacedValues, validBits, validBitsOffset, nulls) } } + if w.bloomFilter != nil { + // TODO: optimize for Dictionary Encoding case + w.bloomFilter.InsertBulk(metadata.GetSpacedHashes(w.bloomFilter.Hasher(), numRead, spacedValues, validBits, validBitsOffset)) + } } func (w *FixedLenByteArrayColumnChunkWriter) checkDictionarySizeLimit() { diff --git a/parquet/file/column_writer_types.gen.go.tmpl b/parquet/file/column_writer_types.gen.go.tmpl index 772777b4..d0a4da26 100644 --- a/parquet/file/column_writer_types.gen.go.tmpl +++ b/parquet/file/column_writer_types.gen.go.tmpl @@ -197,6 +197,10 @@ func (w *{{.Name}}ColumnChunkWriter) writeValues(values []{{.name}}, numNulls in } {{- end}} } + if w.bloomFilter != nil { + // TODO: optimize for Dictionary Encoding case + w.bloomFilter.InsertBulk(metadata.GetHashes(w.bloomFilter.Hasher(), values)) + } } func (w *{{.Name}}ColumnChunkWriter) writeValuesSpaced(spacedValues []{{.name}}, numRead, numValues int64, validBits []byte, validBitsOffset int64) { @@ -217,6 +221,10 @@ func (w 
*{{.Name}}ColumnChunkWriter) writeValuesSpaced(spacedValues []{{.name}}, } {{- end}} } + if w.bloomFilter != nil { + // TODO: optimize for Dictionary Encoding case + w.bloomFilter.InsertBulk(metadata.GetSpacedHashes(w.bloomFilter.Hasher(), numRead, spacedValues, validBits, validBitsOffset)) + } } func (w *{{.Name}}ColumnChunkWriter) checkDictionarySizeLimit() { diff --git a/parquet/file/file_reader.go b/parquet/file/file_reader.go index ddccd794..4025939c 100644 --- a/parquet/file/file_reader.go +++ b/parquet/file/file_reader.go @@ -44,11 +44,12 @@ var ( // Reader is the main interface for reading a parquet file type Reader struct { - r parquet.ReaderAtSeeker - props *parquet.ReaderProperties - metadata *metadata.FileMetaData - fileDecryptor encryption.FileDecryptor - pageIndexReader *metadata.PageIndexReader + r parquet.ReaderAtSeeker + props *parquet.ReaderProperties + metadata *metadata.FileMetaData + fileDecryptor encryption.FileDecryptor + pageIndexReader *metadata.PageIndexReader + bloomFilterReader *metadata.BloomFilterReader bufferPool sync.Pool } @@ -321,9 +322,27 @@ func (f *Reader) RowGroup(i int) *RowGroupReader { fileDecryptor: f.fileDecryptor, bufferPool: &f.bufferPool, pageIndexReader: f.pageIndexReader, + // don't pre-emptively initialize the row group page index reader + // do it on demand, but ensure that it is goroutine safe. + rgPageIndexReader: sync.OnceValues(func() (*metadata.RowGroupPageIndexReader, error) { + return f.pageIndexReader.RowGroup(i) + }), } } func (f *Reader) GetPageIndexReader() *metadata.PageIndexReader { return f.pageIndexReader } + +func (f *Reader) GetBloomFilterReader() *metadata.BloomFilterReader { + if f.bloomFilterReader == nil { + f.bloomFilterReader = &metadata.BloomFilterReader{ + Input: f.r, + FileMetadata: f.metadata, + Props: f.props, + FileDecryptor: f.fileDecryptor, + BufferPool: &f.bufferPool, + } + } + return f.bloomFilterReader +} diff --git a/parquet/file/file_writer.go b/parquet/file/file_writer.go index f6168311..fa8d5db6 100644 --- a/parquet/file/file_writer.go +++ b/parquet/file/file_writer.go @@ -40,6 +40,7 @@ type Writer struct { fileEncryptor encryption.FileEncryptor rowGroupWriter *rowGroupWriter pageIndexBuilder *metadata.PageIndexBuilder + bloomFilters *metadata.FileBloomFilterBuilder // The Schema of this writer Schema *schema.Schema @@ -135,6 +136,7 @@ func (fw *Writer) appendRowGroup(buffered bool) *rowGroupWriter { fw.rowGroupWriter = newRowGroupWriter(fw.sink, rgMeta, int16(fw.rowGroups)-1, fw.props, buffered, fw.fileEncryptor, fw.pageIndexBuilder) + fw.bloomFilters.AppendRowGroup(rgMeta, fw.rowGroupWriter.bloomFilters) return fw.rowGroupWriter } @@ -158,10 +160,17 @@ func (fw *Writer) startFile() { } fw.fileEncryptor = encryption.NewFileEncryptor(encryptionProps, fw.props.Allocator()) + fw.metadata.SetFileEncryptor(fw.fileEncryptor) if encryptionProps.EncryptedFooter() { magic = magicEBytes } } + + fw.bloomFilters = &metadata.FileBloomFilterBuilder{ + Schema: fw.Schema, + Encryptor: fw.fileEncryptor, + } + n, err := fw.sink.Write(magic) if n != 4 || err != nil { panic("failed to write magic number") @@ -213,11 +222,17 @@ func (fw *Writer) Close() (err error) { }() err = fw.FlushWithFooter() - fw.metadata.Clear() } return nil } +// FileMetadata returns the current state of the FileMetadata that would be written +// if this file were to be closed. If the file has already been closed, then this +// will return the FileMetaData which was written to the file. 
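Reader.RowGroup earlier in this hunk wraps the row-group page index lookup in sync.OnceValues (Go 1.21+), so the index is built lazily but at most once even when several column readers race to use it. The same pattern reduced to a standalone sketch; the index type and loader below are placeholders, not arrow-go APIs:

package main

import (
	"fmt"
	"sync"
)

// pageIndex stands in for metadata.RowGroupPageIndexReader.
type pageIndex struct{ rowGroup int }

func loadPageIndex(rg int) (*pageIndex, error) {
	fmt.Println("loading index for row group", rg) // runs exactly once
	return &pageIndex{rowGroup: rg}, nil
}

func main() {
	const rg = 0
	// OnceValues returns a function that invokes the loader on first call and
	// hands every subsequent caller the same (value, error) pair.
	getIndex := sync.OnceValues(func() (*pageIndex, error) { return loadPageIndex(rg) })

	var wg sync.WaitGroup
	for i := 0; i < 4; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			if idx, err := getIndex(); err == nil {
				_ = idx.rowGroup
			}
		}()
	}
	wg.Wait()
}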
+func (fw *Writer) FileMetadata() (*metadata.FileMetaData, error) { + return fw.metadata.Snapshot() +} + // FlushWithFooter closes any open row group writer and writes the file footer, leaving // the writer open for additional row groups. Additional footers written by later // calls to FlushWithFooter or Close will be cumulative, so that only the last footer @@ -229,6 +244,9 @@ func (fw *Writer) FlushWithFooter() error { fw.rowGroupWriter.Close() } fw.rowGroupWriter = nil + if err := fw.bloomFilters.WriteTo(fw.sink); err != nil { + return err + } fw.writePageIndex() diff --git a/parquet/file/file_writer_test.go b/parquet/file/file_writer_test.go index 7bcbe296..ec1f7eb5 100644 --- a/parquet/file/file_writer_test.go +++ b/parquet/file/file_writer_test.go @@ -1067,3 +1067,92 @@ func TestPageIndexRoundTripSuite(t *testing.T) { suite.Run(t, &PageIndexRoundTripSuite{pageVersion: parquet.DataPageV2}) }) } + +func TestWriteBloomFilters(t *testing.T) { + input1 := []parquet.ByteArray{ + parquet.ByteArray("hello"), + parquet.ByteArray("world"), + parquet.ByteArray("hello"), + parquet.ByteArray("parquet"), + } + + input2 := []parquet.ByteArray{ + parquet.ByteArray("foo"), + parquet.ByteArray("bar"), + parquet.ByteArray("baz"), + parquet.ByteArray("columns"), + } + + size := len(input1) + chunk := size / 2 + + props := parquet.NewWriterProperties( + parquet.WithDictionaryDefault(false), + parquet.WithBloomFilterEnabledFor("col1", true), + parquet.WithBatchSize(int64(chunk)), + ) + + field1, err := schema.NewPrimitiveNode("col1", parquet.Repetitions.Required, + parquet.Types.ByteArray, -1, -1) + require.NoError(t, err) + field2, err := schema.NewPrimitiveNode("col2", parquet.Repetitions.Required, + parquet.Types.ByteArray, -1, -1) + require.NoError(t, err) + sc, err := schema.NewGroupNode("test", parquet.Repetitions.Required, + schema.FieldList{field1, field2}, -1) + require.NoError(t, err) + + sink := encoding.NewBufferWriter(0, memory.DefaultAllocator) + writer := file.NewParquetWriter(sink, sc, file.WithWriterProps(props)) + + rgw := writer.AppendRowGroup() + cwr, err := rgw.NextColumn() + require.NoError(t, err) + + cw, ok := cwr.(*file.ByteArrayColumnChunkWriter) + require.True(t, ok) + + nVals, err := cw.WriteBatch(input1[:chunk], nil, nil) + require.NoError(t, err) + require.EqualValues(t, chunk, nVals) + + nVals, err = cw.WriteBatch(input1[chunk:], nil, nil) + require.NoError(t, err) + require.EqualValues(t, chunk, nVals) + + cwr, err = rgw.NextColumn() + require.NoError(t, err) + cw, ok = cwr.(*file.ByteArrayColumnChunkWriter) + require.True(t, ok) + + nVals, err = cw.WriteBatch(input2, nil, nil) + require.NoError(t, err) + require.EqualValues(t, size, nVals) + + require.NoError(t, cwr.Close()) + require.NoError(t, rgw.Close()) + require.NoError(t, writer.Close()) + + rdr, err := file.NewParquetReader(bytes.NewReader(sink.Bytes())) + require.NoError(t, err) + + bloom := rdr.GetBloomFilterReader() + bloomRgr, err := bloom.RowGroup(0) + require.NoError(t, err) + + filter, err := bloomRgr.GetColumnBloomFilter(1) + require.NoError(t, err) + require.Nil(t, filter) // no filter written for col2 + + filter, err = bloomRgr.GetColumnBloomFilter(0) + require.NoError(t, err) + require.NotNil(t, filter) + + byteArrayFilter := metadata.TypedBloomFilter[parquet.ByteArray]{BloomFilter: filter} + assert.True(t, byteArrayFilter.Check(parquet.ByteArray("hello"))) + assert.True(t, byteArrayFilter.Check(parquet.ByteArray("world"))) + assert.True(t, byteArrayFilter.Check(parquet.ByteArray("parquet"))) + 
assert.False(t, byteArrayFilter.Check(parquet.ByteArray("foo"))) + assert.False(t, byteArrayFilter.Check(parquet.ByteArray("bar"))) + assert.False(t, byteArrayFilter.Check(parquet.ByteArray("baz"))) +} diff --git a/parquet/file/page_writer.go b/parquet/file/page_writer.go index 17182921..8d22ef83 100644 --- a/parquet/file/page_writer.go +++ b/parquet/file/page_writer.go @@ -220,7 +220,7 @@ func (pw *serializedPageWriter) Close(hasDict, fallback bool) error { DataEncodingStats: pw.dataEncodingStats, } pw.FinishPageIndexes(0) - pw.metaData.Finish(chunkInfo, hasDict, fallback, encodingStats, pw.metaEncryptor) + pw.metaData.Finish(chunkInfo, hasDict, fallback, encodingStats) _, err := pw.metaData.WriteTo(pw.sink) return err } @@ -505,7 +505,7 @@ func (bw *bufferedPageWriter) Close(hasDict, fallback bool) error { DictEncodingStats: bw.pager.dictEncodingStats, DataEncodingStats: bw.pager.dataEncodingStats, } - bw.metadata.Finish(chunkInfo, hasDict, fallback, encodingStats, bw.pager.metaEncryptor) + bw.metadata.Finish(chunkInfo, hasDict, fallback, encodingStats) bw.pager.FinishPageIndexes(position) bw.metadata.WriteTo(bw.inMemSink) diff --git a/parquet/file/record_reader.go b/parquet/file/record_reader.go index 20d68f47..e2fdcc85 100644 --- a/parquet/file/record_reader.go +++ b/parquet/file/record_reader.go @@ -63,6 +63,7 @@ type RecordReader interface { // ReleaseValues transfers the buffer of data with the values to the caller, // a new buffer will be allocated on subsequent calls. ReleaseValues() *memory.Buffer + ResetValues() // NullCount returns the number of nulls decoded NullCount() int64 // Type returns the parquet physical type of the column @@ -78,6 +79,10 @@ type RecordReader interface { // Release decrements the ref count by one, releasing the internal buffers when // the ref count is 0. Release() + // SeekToRow will shift the record reader so that subsequent reads will + // start at the desired row. It will utilize Offset Indexes if they exist + // to skip pages and seek. 
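The record reader hunks below migrate refCount from a bare int64 updated through atomic.AddInt64 to sync/atomic's atomic.Int64: the method-based type cannot be read or written non-atomically by accident and is guaranteed 64-bit alignment, at the cost of bumping the counter explicitly in constructors because its zero value starts at 0. The pattern in isolation, with illustrative type and field names:

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// resource follows the Retain/Release convention used by the record readers:
// the counter starts at zero, so constructors must Add(1) for the initial reference.
type resource struct {
	refCount atomic.Int64
	closed   bool
}

func newResource() *resource {
	r := &resource{}
	r.refCount.Add(1) // initial reference held by the caller
	return r
}

func (r *resource) Retain() { r.refCount.Add(1) }

func (r *resource) Release() {
	if r.refCount.Add(-1) == 0 {
		r.closed = true // free underlying buffers here
	}
}

func main() {
	r := newResource()
	var wg sync.WaitGroup
	for i := 0; i < 8; i++ {
		wg.Add(1)
		r.Retain()
		go func() {
			defer wg.Done()
			defer r.Release()
			// ... use r concurrently ...
		}()
	}
	wg.Wait()
	r.Release() // drop the constructor's reference
	fmt.Println("closed:", r.closed)
}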
+ SeekToRow(int64) error } // BinaryRecordReader provides an extra GetBuilderChunks function above and beyond @@ -125,12 +130,12 @@ type primitiveRecordReader struct { validBits *memory.Buffer mem memory.Allocator - refCount int64 + refCount atomic.Int64 useValues bool } -func createPrimitiveRecordReader(descr *schema.Column, mem memory.Allocator, bufferPool *sync.Pool) primitiveRecordReader { - return primitiveRecordReader{ +func createPrimitiveRecordReader(descr *schema.Column, mem memory.Allocator, bufferPool *sync.Pool) *primitiveRecordReader { + prr := &primitiveRecordReader{ ColumnChunkReader: newTypedColumnChunkReader(columnChunkReader{ descr: descr, mem: mem, @@ -139,17 +144,19 @@ func createPrimitiveRecordReader(descr *schema.Column, mem memory.Allocator, buf values: memory.NewResizableBuffer(mem), validBits: memory.NewResizableBuffer(mem), mem: mem, - refCount: 1, useValues: descr.PhysicalType() != parquet.Types.ByteArray && descr.PhysicalType() != parquet.Types.FixedLenByteArray, } + + prr.refCount.Add(1) + return prr } func (pr *primitiveRecordReader) Retain() { - atomic.AddInt64(&pr.refCount, 1) + pr.refCount.Add(1) } func (pr *primitiveRecordReader) Release() { - if atomic.AddInt64(&pr.refCount, -1) == 0 { + if pr.refCount.Add(-1) == 0 { if pr.values != nil { pr.values.Release() pr.values = nil @@ -320,7 +327,7 @@ type recordReader struct { defLevels *memory.Buffer repLevels *memory.Buffer - refCount int64 + refCount atomic.Int64 } // binaryRecordReader is the recordReaderImpl for non-primitive data @@ -341,22 +348,22 @@ func newRecordReader(descr *schema.Column, info LevelInfo, mem memory.Allocator, mem = memory.DefaultAllocator } - pr := createPrimitiveRecordReader(descr, mem, bufferPool) - return &recordReader{ - refCount: 1, - recordReaderImpl: &pr, + rr := &recordReader{ + recordReaderImpl: createPrimitiveRecordReader(descr, mem, bufferPool), leafInfo: info, defLevels: memory.NewResizableBuffer(mem), repLevels: memory.NewResizableBuffer(mem), } + rr.refCount.Add(1) + return rr } func (rr *recordReader) Retain() { - atomic.AddInt64(&rr.refCount, 1) + rr.refCount.Add(1) } func (rr *recordReader) Release() { - if atomic.AddInt64(&rr.refCount, -1) == 0 { + if rr.refCount.Add(-1) == 0 { rr.recordReaderImpl.Release() rr.defLevels.Release() rr.repLevels.Release() @@ -440,12 +447,27 @@ func (rr *recordReader) reserveValues(extra int64) error { return rr.recordReaderImpl.ReserveValues(extra, rr.leafInfo.HasNullableValues()) } -func (rr *recordReader) resetValues() { +func (rr *recordReader) ResetValues() { rr.recordReaderImpl.ResetValues() } +func (rr *recordReader) SeekToRow(recordIdx int64) error { + if err := rr.recordReaderImpl.SeekToRow(recordIdx); err != nil { + return err + } + + rr.atRecStart = true + rr.recordsRead = 0 + // force re-reading the definition/repetition levels + // calling SeekToRow on the underlying column reader will ensure that + // the next reads will pull from the correct row + rr.levelsPos, rr.levelsWritten = 0, 0 + + return nil +} + func (rr *recordReader) Reset() { - rr.resetValues() + rr.ResetValues() if rr.levelsWritten > 0 { remain := int(rr.levelsWritten - rr.levelsPos) @@ -741,17 +763,18 @@ func newFLBARecordReader(descr *schema.Column, info LevelInfo, mem memory.Alloca byteWidth := descr.TypeLength() - return &binaryRecordReader{&recordReader{ + brr := &binaryRecordReader{&recordReader{ recordReaderImpl: &flbaRecordReader{ - createPrimitiveRecordReader(descr, mem, bufferPool), + *createPrimitiveRecordReader(descr, mem, bufferPool), 
array.NewFixedSizeBinaryBuilder(mem, &arrow.FixedSizeBinaryType{ByteWidth: byteWidth}), nil, }, leafInfo: info, defLevels: memory.NewResizableBuffer(mem), repLevels: memory.NewResizableBuffer(mem), - refCount: 1, }} + brr.refCount.Add(1) + return brr } // byteArrayRecordReader is the specialization impl for byte-array columns @@ -773,17 +796,18 @@ func newByteArrayRecordReader(descr *schema.Column, info LevelInfo, dtype arrow. dt = arrow.BinaryTypes.Binary } - return &binaryRecordReader{&recordReader{ + brr := &binaryRecordReader{&recordReader{ recordReaderImpl: &byteArrayRecordReader{ - createPrimitiveRecordReader(descr, mem, bufferPool), + *createPrimitiveRecordReader(descr, mem, bufferPool), array.NewBinaryBuilder(mem, dt), nil, }, leafInfo: info, defLevels: memory.NewResizableBuffer(mem), repLevels: memory.NewResizableBuffer(mem), - refCount: 1, }} + brr.refCount.Add(1) + return brr } func (br *byteArrayRecordReader) ReserveValues(extra int64, hasNullable bool) error { @@ -893,10 +917,10 @@ func newByteArrayDictRecordReader(descr *schema.Column, info LevelInfo, dtype ar dt.ValueType = arrow.BinaryTypes.Binary } - return &binaryRecordReader{&recordReader{ + brr := &binaryRecordReader{&recordReader{ recordReaderImpl: &byteArrayDictRecordReader{ byteArrayRecordReader: byteArrayRecordReader{ - createPrimitiveRecordReader(descr, mem, bufferPool), + *createPrimitiveRecordReader(descr, mem, bufferPool), array.NewDictionaryBuilder(mem, dt), nil, }, @@ -905,8 +929,10 @@ func newByteArrayDictRecordReader(descr *schema.Column, info LevelInfo, dtype ar leafInfo: info, defLevels: memory.NewResizableBuffer(mem), repLevels: memory.NewResizableBuffer(mem), - refCount: 1, }} + + brr.refCount.Add(1) + return brr } func (bd *byteArrayDictRecordReader) GetBuilderChunks() []arrow.Array { diff --git a/parquet/file/row_group_reader.go b/parquet/file/row_group_reader.go index acfac0ea..ea5f7098 100644 --- a/parquet/file/row_group_reader.go +++ b/parquet/file/row_group_reader.go @@ -42,7 +42,7 @@ type RowGroupReader struct { fileDecryptor encryption.FileDecryptor pageIndexReader *metadata.PageIndexReader - rgPageIndexReader *metadata.RowGroupPageIndexReader + rgPageIndexReader func() (*metadata.RowGroupPageIndexReader, error) bufferPool *sync.Pool } @@ -86,12 +86,9 @@ func (r *RowGroupReader) GetColumnPageReader(i int) (PageReader, error) { return nil, err } - if r.rgPageIndexReader == nil { - rgIdx, err := r.pageIndexReader.RowGroup(int(r.rgMetadata.Ordinal())) - if err != nil { - return nil, err - } - r.rgPageIndexReader = rgIdx + rgIdxRdr, err := r.rgPageIndexReader() + if err != nil { + return nil, err } colStart := col.DataPageOffset() @@ -128,7 +125,7 @@ func (r *RowGroupReader) GetColumnPageReader(i int) (PageReader, error) { r: stream, chunk: col, colIdx: i, - pgIndexReader: r.rgPageIndexReader, + pgIndexReader: rgIdxRdr, maxPageHeaderSize: defaultMaxPageHeaderSize, nrows: col.NumValues(), mem: r.props.Allocator(), @@ -157,7 +154,7 @@ func (r *RowGroupReader) GetColumnPageReader(i int) (PageReader, error) { r: stream, chunk: col, colIdx: i, - pgIndexReader: r.rgPageIndexReader, + pgIndexReader: rgIdxRdr, maxPageHeaderSize: defaultMaxPageHeaderSize, nrows: col.NumValues(), mem: r.props.Allocator(), @@ -181,7 +178,7 @@ func (r *RowGroupReader) GetColumnPageReader(i int) (PageReader, error) { r: stream, chunk: col, colIdx: i, - pgIndexReader: r.rgPageIndexReader, + pgIndexReader: rgIdxRdr, maxPageHeaderSize: defaultMaxPageHeaderSize, nrows: col.NumValues(), mem: r.props.Allocator(), diff --git 
a/parquet/file/row_group_writer.go b/parquet/file/row_group_writer.go index c335a105..f3732c1d 100644 --- a/parquet/file/row_group_writer.go +++ b/parquet/file/row_group_writer.go @@ -80,18 +80,22 @@ type rowGroupWriter struct { columnWriters []ColumnChunkWriter pager PageWriter + + bloomFilters map[string]metadata.BloomFilterBuilder } -func newRowGroupWriter(sink utils.WriterTell, metadata *metadata.RowGroupMetaDataBuilder, ordinal int16, props *parquet.WriterProperties, buffered bool, fileEncryptor encryption.FileEncryptor, pageIdxBldr *metadata.PageIndexBuilder) *rowGroupWriter { +func newRowGroupWriter(sink utils.WriterTell, rgMeta *metadata.RowGroupMetaDataBuilder, ordinal int16, props *parquet.WriterProperties, buffered bool, fileEncryptor encryption.FileEncryptor, pageIdxBldr *metadata.PageIndexBuilder) *rowGroupWriter { ret := &rowGroupWriter{ sink: sink, - metadata: metadata, + metadata: rgMeta, props: props, ordinal: ordinal, buffered: buffered, fileEncryptor: fileEncryptor, pageIndexBuilder: pageIdxBldr, + bloomFilters: make(map[string]metadata.BloomFilterBuilder), } + if buffered { ret.initColumns() } else { @@ -187,6 +191,7 @@ func (rg *rowGroupWriter) NextColumn() (ColumnChunkWriter, error) { } rg.columnWriters[0] = NewColumnChunkWriter(colMeta, rg.pager, rg.props) + rg.bloomFilters[path] = rg.columnWriters[0].GetBloomFilter() return rg.columnWriters[0], nil } @@ -279,7 +284,9 @@ func (rg *rowGroupWriter) initColumns() error { pager.SetIndexBuilders(colIdxBldr, offsetIdxBldr) rg.nextColumnIdx++ - rg.columnWriters = append(rg.columnWriters, NewColumnChunkWriter(colMeta, pager, rg.props)) + cw := NewColumnChunkWriter(colMeta, pager, rg.props) + rg.columnWriters = append(rg.columnWriters, cw) + rg.bloomFilters[path] = cw.GetBloomFilter() } return nil } diff --git a/parquet/internal/bmi/bmi.go b/parquet/internal/bmi/bmi.go index a12af3e7..7139d6fe 100644 --- a/parquet/internal/bmi/bmi.go +++ b/parquet/internal/bmi/bmi.go @@ -19,14 +19,20 @@ // BMI2. 
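newRowGroupWriter above now keeps a metadata.BloomFilterBuilder per column path, but a builder is only created when the writer properties ask for one (see initBloomFilter in column_writer.go earlier in this patch); otherwise GetBloomFilter returns nil and nothing extra is written. A minimal sketch of turning the filter on for a single column, using only options that already appear in this patch's tests; anything beyond these is an assumption:

package main

import "github.com/apache/arrow-go/v18/parquet"

func writerPropsWithBloomFilter() *parquet.WriterProperties {
	// Only the column at path "col1" gets a bloom filter; columns with no
	// explicit setting keep the default and add no filter bytes to the file.
	return parquet.NewWriterProperties(
		parquet.WithDictionaryDefault(false),
		parquet.WithBloomFilterEnabledFor("col1", true),
	)
}

Filters built from these properties are serialized by FlushWithFooter/Close ahead of the page index, as the file_writer.go hunk above shows, and can be read back through Reader.GetBloomFilterReader.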
package bmi -import "math/bits" +import ( + "math/bits" +) type funcs struct { extractBits func(uint64, uint64) uint64 gtbitmap func([]int16, int16) uint64 } -var funclist funcs +// fallback until arch specific init() is called: +var funclist = funcs{ + extractBits: extractBitsGo, + gtbitmap: greaterThanBitmapGo, +} // ExtractBits performs a Parallel Bit extract as per the PEXT instruction for // x86/x86-64 cpus to use the second parameter as a mask to extract the bits from diff --git a/parquet/internal/encoding/fixed_len_byte_array_encoder.go b/parquet/internal/encoding/fixed_len_byte_array_encoder.go index 3f2a0971..56cf242b 100644 --- a/parquet/internal/encoding/fixed_len_byte_array_encoder.go +++ b/parquet/internal/encoding/fixed_len_byte_array_encoder.go @@ -41,11 +41,15 @@ func (enc *PlainFixedLenByteArrayEncoder) Put(in []parquet.FixedLenByteArray) { bytesNeeded := len(in) * typeLen enc.sink.Reserve(bytesNeeded) + + emptyValue := make([]byte, typeLen) + for _, val := range in { if val == nil { - panic("value cannot be nil") + enc.sink.UnsafeWrite(emptyValue) + } else { + enc.sink.UnsafeWrite(val[:typeLen]) } - enc.sink.UnsafeWrite(val[:typeLen]) } } diff --git a/parquet/internal/encoding/fixed_len_byte_array_encoder_test.go b/parquet/internal/encoding/fixed_len_byte_array_encoder_test.go new file mode 100644 index 00000000..67e83b02 --- /dev/null +++ b/parquet/internal/encoding/fixed_len_byte_array_encoder_test.go @@ -0,0 +1,80 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
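The bmi change above seeds funclist with the portable Go implementations, so the package behaves correctly even if no architecture-specific init() ever runs to swap in the PEXT/BMI2-backed versions. The shape of that pattern as a standalone sketch; the names below are illustrative, not the bmi package's:

package main

import (
	"fmt"
	"math/bits"
)

// dispatch holds the currently selected implementation of each hot function.
type dispatch struct {
	popcount func(uint64) int
}

// Portable fallback, correct on every platform.
func popcountGo(v uint64) int { return bits.OnesCount64(v) }

// The package-level table starts out on the portable versions; a build-tagged,
// arch-specific init() may overwrite the entries with accelerated ones at startup.
var impl = dispatch{popcount: popcountGo}

func main() {
	fmt.Println(impl.popcount(0b1011)) // 3
}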
+ +package encoding + +import ( + "testing" + + "github.com/apache/arrow-go/v18/parquet" + "github.com/apache/arrow-go/v18/parquet/schema" + "github.com/stretchr/testify/require" +) + +func TestPlainFixedLenByteArrayEncoder_Put(t *testing.T) { + sink := NewPooledBufferWriter(0) + elem := schema.NewFixedLenByteArrayNode("test", parquet.Repetitions.Required, 4, 0) + descr := schema.NewColumn(elem, 0, 0) + encoder := &PlainFixedLenByteArrayEncoder{ + encoder: encoder{ + descr: descr, + sink: sink, + }, + } + + tests := []struct { + name string + input []parquet.FixedLenByteArray + expected []byte + }{ + { + name: "Normal input", + input: []parquet.FixedLenByteArray{ + []byte("abcd"), + []byte("efgh"), + []byte("ijkl"), + }, + expected: []byte("abcdefghijkl"), + }, + { + name: "Input with nil values", + input: []parquet.FixedLenByteArray{ + []byte("abcd"), + nil, + []byte("ijkl"), + }, + expected: []byte("abcd\x00\x00\x00\x00ijkl"), // Nil replaced with zero bytes + }, + { + name: "Empty input", + input: []parquet.FixedLenByteArray{}, + expected: []byte{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Reset the sink before each test + sink.Reset(0) + + // Perform the encoding + encoder.Put(tt.input) + + // Assert the result + require.Equal(t, tt.expected, sink.Bytes(), "Encoded bytes should match expected output") + }) + } +} diff --git a/parquet/internal/encryption/aes.go b/parquet/internal/encryption/aes.go index 0c16facf..06c9c90c 100644 --- a/parquet/internal/encryption/aes.go +++ b/parquet/internal/encryption/aes.go @@ -54,6 +54,8 @@ const ( DictPageHeaderModule ColumnIndexModule OffsetIndexModule + BloomFilterHeaderModule + BloomFilterBitsetModule ) type aesEncryptor struct { diff --git a/parquet/internal/gen-go/parquet/parquet.go b/parquet/internal/gen-go/parquet/parquet.go index 88bc456a..db24d390 100644 --- a/parquet/internal/gen-go/parquet/parquet.go +++ b/parquet/internal/gen-go/parquet/parquet.go @@ -279,6 +279,70 @@ func (p *FieldRepetitionType) Value() (driver.Value, error) { return int64(*p), nil } +//Edge interpolation algorithm for Geography logical type +type EdgeInterpolationAlgorithm int64 +const ( + EdgeInterpolationAlgorithm_SPHERICAL EdgeInterpolationAlgorithm = 0 + EdgeInterpolationAlgorithm_VINCENTY EdgeInterpolationAlgorithm = 1 + EdgeInterpolationAlgorithm_THOMAS EdgeInterpolationAlgorithm = 2 + EdgeInterpolationAlgorithm_ANDOYER EdgeInterpolationAlgorithm = 3 + EdgeInterpolationAlgorithm_KARNEY EdgeInterpolationAlgorithm = 4 +) + +func (p EdgeInterpolationAlgorithm) String() string { + switch p { + case EdgeInterpolationAlgorithm_SPHERICAL: return "SPHERICAL" + case EdgeInterpolationAlgorithm_VINCENTY: return "VINCENTY" + case EdgeInterpolationAlgorithm_THOMAS: return "THOMAS" + case EdgeInterpolationAlgorithm_ANDOYER: return "ANDOYER" + case EdgeInterpolationAlgorithm_KARNEY: return "KARNEY" + } + return "" +} + +func EdgeInterpolationAlgorithmFromString(s string) (EdgeInterpolationAlgorithm, error) { + switch s { + case "SPHERICAL": return EdgeInterpolationAlgorithm_SPHERICAL, nil + case "VINCENTY": return EdgeInterpolationAlgorithm_VINCENTY, nil + case "THOMAS": return EdgeInterpolationAlgorithm_THOMAS, nil + case "ANDOYER": return EdgeInterpolationAlgorithm_ANDOYER, nil + case "KARNEY": return EdgeInterpolationAlgorithm_KARNEY, nil + } + return EdgeInterpolationAlgorithm(0), fmt.Errorf("not a valid EdgeInterpolationAlgorithm string") +} + + +func EdgeInterpolationAlgorithmPtr(v EdgeInterpolationAlgorithm) 
*EdgeInterpolationAlgorithm { return &v } + +func (p EdgeInterpolationAlgorithm) MarshalText() ([]byte, error) { + return []byte(p.String()), nil +} + +func (p *EdgeInterpolationAlgorithm) UnmarshalText(text []byte) error { + q, err := EdgeInterpolationAlgorithmFromString(string(text)) + if err != nil { + return err + } + *p = q + return nil +} + +func (p *EdgeInterpolationAlgorithm) Scan(value interface{}) error { + v, ok := value.(int64) + if !ok { + return errors.New("Scan value is not int64") + } + *p = EdgeInterpolationAlgorithm(v) + return nil +} + +func (p *EdgeInterpolationAlgorithm) Value() (driver.Value, error) { + if p == nil { + return nil, nil + } + return int64(*p), nil +} + //Encodings supported by Parquet. Not all encodings are valid for all types. These //enums are also used to specify the encoding of definition and repetition levels. //See the accompanying doc for the details of the more complicated encodings. @@ -878,159 +942,119 @@ func (p *SizeStatistics) Validate() error { return nil } -// Statistics per row group and per page -// All fields are optional. +// Bounding box for GEOMETRY or GEOGRAPHY type in the representation of min/max +// value pair of coordinates from each axis. // // Attributes: -// - Max: DEPRECATED: min and max value of the column. Use min_value and max_value. -// -// Values are encoded using PLAIN encoding, except that variable-length byte -// arrays do not include a length prefix. -// -// These fields encode min and max values determined by signed comparison -// only. New files should use the correct order for a column's logical type -// and store the values in the min_value and max_value fields. -// -// To support older readers, these may be set when the column order is -// signed. -// - Min -// - NullCount: Count of null values in the column. -// -// Writers SHOULD always write this field even if it is zero (i.e. no null value) -// or the column is not nullable. -// Readers MUST distinguish between null_count not being present and null_count == 0. -// If null_count is not present, readers MUST NOT assume null_count == 0. -// - DistinctCount: count of distinct values occurring -// - MaxValue: Lower and upper bound values for the column, determined by its ColumnOrder. -// -// These may be the actual minimum and maximum values found on a page or column -// chunk, but can also be (more compact) values that do not exist on a page or -// column chunk. For example, instead of storing "Blart Versenwald III", a writer -// may set min_value="B", max_value="C". Such more compact values must still be -// valid values within the column's logical type. -// -// Values are encoded using PLAIN encoding, except that variable-length byte -// arrays do not include a length prefix. 
-// - MinValue -// - IsMaxValueExact: If true, max_value is the actual maximum value for a column -// - IsMinValueExact: If true, min_value is the actual minimum value for a column +// - Xmin +// - Xmax +// - Ymin +// - Ymax +// - Zmin +// - Zmax +// - Mmin +// - Mmax // -type Statistics struct { - Max []byte `thrift:"max,1" db:"max" json:"max,omitempty"` - Min []byte `thrift:"min,2" db:"min" json:"min,omitempty"` - NullCount *int64 `thrift:"null_count,3" db:"null_count" json:"null_count,omitempty"` - DistinctCount *int64 `thrift:"distinct_count,4" db:"distinct_count" json:"distinct_count,omitempty"` - MaxValue []byte `thrift:"max_value,5" db:"max_value" json:"max_value,omitempty"` - MinValue []byte `thrift:"min_value,6" db:"min_value" json:"min_value,omitempty"` - IsMaxValueExact *bool `thrift:"is_max_value_exact,7" db:"is_max_value_exact" json:"is_max_value_exact,omitempty"` - IsMinValueExact *bool `thrift:"is_min_value_exact,8" db:"is_min_value_exact" json:"is_min_value_exact,omitempty"` +type BoundingBox struct { + Xmin float64 `thrift:"xmin,1,required" db:"xmin" json:"xmin"` + Xmax float64 `thrift:"xmax,2,required" db:"xmax" json:"xmax"` + Ymin float64 `thrift:"ymin,3,required" db:"ymin" json:"ymin"` + Ymax float64 `thrift:"ymax,4,required" db:"ymax" json:"ymax"` + Zmin *float64 `thrift:"zmin,5" db:"zmin" json:"zmin,omitempty"` + Zmax *float64 `thrift:"zmax,6" db:"zmax" json:"zmax,omitempty"` + Mmin *float64 `thrift:"mmin,7" db:"mmin" json:"mmin,omitempty"` + Mmax *float64 `thrift:"mmax,8" db:"mmax" json:"mmax,omitempty"` } -func NewStatistics() *Statistics { - return &Statistics{} -} - -var Statistics_Max_DEFAULT []byte - - -func (p *Statistics) GetMax() []byte { - return p.Max +func NewBoundingBox() *BoundingBox { + return &BoundingBox{} } -var Statistics_Min_DEFAULT []byte -func (p *Statistics) GetMin() []byte { - return p.Min +func (p *BoundingBox) GetXmin() float64 { + return p.Xmin } -var Statistics_NullCount_DEFAULT int64 - -func (p *Statistics) GetNullCount() int64 { - if !p.IsSetNullCount() { - return Statistics_NullCount_DEFAULT - } - return *p.NullCount -} -var Statistics_DistinctCount_DEFAULT int64 -func (p *Statistics) GetDistinctCount() int64 { - if !p.IsSetDistinctCount() { - return Statistics_DistinctCount_DEFAULT - } - return *p.DistinctCount +func (p *BoundingBox) GetXmax() float64 { + return p.Xmax } -var Statistics_MaxValue_DEFAULT []byte -func (p *Statistics) GetMaxValue() []byte { - return p.MaxValue +func (p *BoundingBox) GetYmin() float64 { + return p.Ymin } -var Statistics_MinValue_DEFAULT []byte -func (p *Statistics) GetMinValue() []byte { - return p.MinValue +func (p *BoundingBox) GetYmax() float64 { + return p.Ymax } -var Statistics_IsMaxValueExact_DEFAULT bool +var BoundingBox_Zmin_DEFAULT float64 -func (p *Statistics) GetIsMaxValueExact() bool { - if !p.IsSetIsMaxValueExact() { - return Statistics_IsMaxValueExact_DEFAULT +func (p *BoundingBox) GetZmin() float64 { + if !p.IsSetZmin() { + return BoundingBox_Zmin_DEFAULT } - return *p.IsMaxValueExact + return *p.Zmin } -var Statistics_IsMinValueExact_DEFAULT bool +var BoundingBox_Zmax_DEFAULT float64 -func (p *Statistics) GetIsMinValueExact() bool { - if !p.IsSetIsMinValueExact() { - return Statistics_IsMinValueExact_DEFAULT +func (p *BoundingBox) GetZmax() float64 { + if !p.IsSetZmax() { + return BoundingBox_Zmax_DEFAULT } - return *p.IsMinValueExact + return *p.Zmax } -func (p *Statistics) IsSetMax() bool { - return p.Max != nil -} +var BoundingBox_Mmin_DEFAULT float64 -func (p *Statistics) IsSetMin() bool 
{ - return p.Min != nil +func (p *BoundingBox) GetMmin() float64 { + if !p.IsSetMmin() { + return BoundingBox_Mmin_DEFAULT + } + return *p.Mmin } -func (p *Statistics) IsSetNullCount() bool { - return p.NullCount != nil -} +var BoundingBox_Mmax_DEFAULT float64 -func (p *Statistics) IsSetDistinctCount() bool { - return p.DistinctCount != nil +func (p *BoundingBox) GetMmax() float64 { + if !p.IsSetMmax() { + return BoundingBox_Mmax_DEFAULT + } + return *p.Mmax } -func (p *Statistics) IsSetMaxValue() bool { - return p.MaxValue != nil +func (p *BoundingBox) IsSetZmin() bool { + return p.Zmin != nil } -func (p *Statistics) IsSetMinValue() bool { - return p.MinValue != nil +func (p *BoundingBox) IsSetZmax() bool { + return p.Zmax != nil } -func (p *Statistics) IsSetIsMaxValueExact() bool { - return p.IsMaxValueExact != nil +func (p *BoundingBox) IsSetMmin() bool { + return p.Mmin != nil } -func (p *Statistics) IsSetIsMinValueExact() bool { - return p.IsMinValueExact != nil +func (p *BoundingBox) IsSetMmax() bool { + return p.Mmax != nil } -func (p *Statistics) Read(ctx context.Context, iprot thrift.TProtocol) error { +func (p *BoundingBox) Read(ctx context.Context, iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(ctx); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } + var issetXmin bool = false; + var issetXmax bool = false; + var issetYmin bool = false; + var issetYmax bool = false; for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin(ctx) @@ -1042,47 +1066,51 @@ func (p *Statistics) Read(ctx context.Context, iprot thrift.TProtocol) error { } switch fieldId { case 1: - if fieldTypeId == thrift.STRING { + if fieldTypeId == thrift.DOUBLE { if err := p.ReadField1(ctx, iprot); err != nil { return err } + issetXmin = true } else { if err := iprot.Skip(ctx, fieldTypeId); err != nil { return err } } case 2: - if fieldTypeId == thrift.STRING { + if fieldTypeId == thrift.DOUBLE { if err := p.ReadField2(ctx, iprot); err != nil { return err } + issetXmax = true } else { if err := iprot.Skip(ctx, fieldTypeId); err != nil { return err } } case 3: - if fieldTypeId == thrift.I64 { + if fieldTypeId == thrift.DOUBLE { if err := p.ReadField3(ctx, iprot); err != nil { return err } + issetYmin = true } else { if err := iprot.Skip(ctx, fieldTypeId); err != nil { return err } } case 4: - if fieldTypeId == thrift.I64 { + if fieldTypeId == thrift.DOUBLE { if err := p.ReadField4(ctx, iprot); err != nil { return err } + issetYmax = true } else { if err := iprot.Skip(ctx, fieldTypeId); err != nil { return err } } case 5: - if fieldTypeId == thrift.STRING { + if fieldTypeId == thrift.DOUBLE { if err := p.ReadField5(ctx, iprot); err != nil { return err } @@ -1092,7 +1120,7 @@ func (p *Statistics) Read(ctx context.Context, iprot thrift.TProtocol) error { } } case 6: - if fieldTypeId == thrift.STRING { + if fieldTypeId == thrift.DOUBLE { if err := p.ReadField6(ctx, iprot); err != nil { return err } @@ -1102,7 +1130,7 @@ func (p *Statistics) Read(ctx context.Context, iprot thrift.TProtocol) error { } } case 7: - if fieldTypeId == thrift.BOOL { + if fieldTypeId == thrift.DOUBLE { if err := p.ReadField7(ctx, iprot); err != nil { return err } @@ -1112,7 +1140,7 @@ func (p *Statistics) Read(ctx context.Context, iprot thrift.TProtocol) error { } } case 8: - if fieldTypeId == thrift.BOOL { + if fieldTypeId == thrift.DOUBLE { if err := p.ReadField8(ctx, iprot); err != nil { return err } @@ -1133,83 +1161,95 @@ func (p *Statistics) Read(ctx context.Context, iprot 
thrift.TProtocol) error { if err := iprot.ReadStructEnd(ctx); err != nil { return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } + if !issetXmin{ + return thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("Required field Xmin is not set")); + } + if !issetXmax{ + return thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("Required field Xmax is not set")); + } + if !issetYmin{ + return thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("Required field Ymin is not set")); + } + if !issetYmax{ + return thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("Required field Ymax is not set")); + } return nil } -func (p *Statistics) ReadField1(ctx context.Context, iprot thrift.TProtocol) error { - if v, err := iprot.ReadBinary(ctx); err != nil { +func (p *BoundingBox) ReadField1(ctx context.Context, iprot thrift.TProtocol) error { + if v, err := iprot.ReadDouble(ctx); err != nil { return thrift.PrependError("error reading field 1: ", err) } else { - p.Max = v + p.Xmin = v } return nil } -func (p *Statistics) ReadField2(ctx context.Context, iprot thrift.TProtocol) error { - if v, err := iprot.ReadBinary(ctx); err != nil { +func (p *BoundingBox) ReadField2(ctx context.Context, iprot thrift.TProtocol) error { + if v, err := iprot.ReadDouble(ctx); err != nil { return thrift.PrependError("error reading field 2: ", err) } else { - p.Min = v + p.Xmax = v } return nil } -func (p *Statistics) ReadField3(ctx context.Context, iprot thrift.TProtocol) error { - if v, err := iprot.ReadI64(ctx); err != nil { +func (p *BoundingBox) ReadField3(ctx context.Context, iprot thrift.TProtocol) error { + if v, err := iprot.ReadDouble(ctx); err != nil { return thrift.PrependError("error reading field 3: ", err) } else { - p.NullCount = &v + p.Ymin = v } return nil } -func (p *Statistics) ReadField4(ctx context.Context, iprot thrift.TProtocol) error { - if v, err := iprot.ReadI64(ctx); err != nil { +func (p *BoundingBox) ReadField4(ctx context.Context, iprot thrift.TProtocol) error { + if v, err := iprot.ReadDouble(ctx); err != nil { return thrift.PrependError("error reading field 4: ", err) } else { - p.DistinctCount = &v + p.Ymax = v } return nil } -func (p *Statistics) ReadField5(ctx context.Context, iprot thrift.TProtocol) error { - if v, err := iprot.ReadBinary(ctx); err != nil { +func (p *BoundingBox) ReadField5(ctx context.Context, iprot thrift.TProtocol) error { + if v, err := iprot.ReadDouble(ctx); err != nil { return thrift.PrependError("error reading field 5: ", err) } else { - p.MaxValue = v + p.Zmin = &v } return nil } -func (p *Statistics) ReadField6(ctx context.Context, iprot thrift.TProtocol) error { - if v, err := iprot.ReadBinary(ctx); err != nil { +func (p *BoundingBox) ReadField6(ctx context.Context, iprot thrift.TProtocol) error { + if v, err := iprot.ReadDouble(ctx); err != nil { return thrift.PrependError("error reading field 6: ", err) } else { - p.MinValue = v + p.Zmax = &v } return nil } -func (p *Statistics) ReadField7(ctx context.Context, iprot thrift.TProtocol) error { - if v, err := iprot.ReadBool(ctx); err != nil { +func (p *BoundingBox) ReadField7(ctx context.Context, iprot thrift.TProtocol) error { + if v, err := iprot.ReadDouble(ctx); err != nil { return thrift.PrependError("error reading field 7: ", err) } else { - p.IsMaxValueExact = &v + p.Mmin = &v } return nil } -func (p *Statistics) ReadField8(ctx context.Context, iprot thrift.TProtocol) error { - if v, err := iprot.ReadBool(ctx); err != 
nil { +func (p *BoundingBox) ReadField8(ctx context.Context, iprot thrift.TProtocol) error { + if v, err := iprot.ReadDouble(ctx); err != nil { return thrift.PrependError("error reading field 8: ", err) } else { - p.IsMinValueExact = &v + p.Mmax = &v } return nil } -func (p *Statistics) Write(ctx context.Context, oprot thrift.TProtocol) error { - if err := oprot.WriteStructBegin(ctx, "Statistics"); err != nil { +func (p *BoundingBox) Write(ctx context.Context, oprot thrift.TProtocol) error { + if err := oprot.WriteStructBegin(ctx, "BoundingBox"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if p != nil { @@ -1231,278 +1271,219 @@ func (p *Statistics) Write(ctx context.Context, oprot thrift.TProtocol) error { return nil } -func (p *Statistics) writeField1(ctx context.Context, oprot thrift.TProtocol) (err error) { - if p.IsSetMax() { - if err := oprot.WriteFieldBegin(ctx, "max", thrift.STRING, 1); err != nil { - return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:max: ", p), err) - } - if err := oprot.WriteBinary(ctx, p.Max); err != nil { - return thrift.PrependError(fmt.Sprintf("%T.max (1) field write error: ", p), err) - } - if err := oprot.WriteFieldEnd(ctx); err != nil { - return thrift.PrependError(fmt.Sprintf("%T write field end error 1:max: ", p), err) - } +func (p *BoundingBox) writeField1(ctx context.Context, oprot thrift.TProtocol) (err error) { + if err := oprot.WriteFieldBegin(ctx, "xmin", thrift.DOUBLE, 1); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:xmin: ", p), err) + } + if err := oprot.WriteDouble(ctx, float64(p.Xmin)); err != nil { + return thrift.PrependError(fmt.Sprintf("%T.xmin (1) field write error: ", p), err) + } + if err := oprot.WriteFieldEnd(ctx); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field end error 1:xmin: ", p), err) } return err } -func (p *Statistics) writeField2(ctx context.Context, oprot thrift.TProtocol) (err error) { - if p.IsSetMin() { - if err := oprot.WriteFieldBegin(ctx, "min", thrift.STRING, 2); err != nil { - return thrift.PrependError(fmt.Sprintf("%T write field begin error 2:min: ", p), err) - } - if err := oprot.WriteBinary(ctx, p.Min); err != nil { - return thrift.PrependError(fmt.Sprintf("%T.min (2) field write error: ", p), err) - } - if err := oprot.WriteFieldEnd(ctx); err != nil { - return thrift.PrependError(fmt.Sprintf("%T write field end error 2:min: ", p), err) - } +func (p *BoundingBox) writeField2(ctx context.Context, oprot thrift.TProtocol) (err error) { + if err := oprot.WriteFieldBegin(ctx, "xmax", thrift.DOUBLE, 2); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field begin error 2:xmax: ", p), err) + } + if err := oprot.WriteDouble(ctx, float64(p.Xmax)); err != nil { + return thrift.PrependError(fmt.Sprintf("%T.xmax (2) field write error: ", p), err) + } + if err := oprot.WriteFieldEnd(ctx); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field end error 2:xmax: ", p), err) } return err } -func (p *Statistics) writeField3(ctx context.Context, oprot thrift.TProtocol) (err error) { - if p.IsSetNullCount() { - if err := oprot.WriteFieldBegin(ctx, "null_count", thrift.I64, 3); err != nil { - return thrift.PrependError(fmt.Sprintf("%T write field begin error 3:null_count: ", p), err) +func (p *BoundingBox) writeField3(ctx context.Context, oprot thrift.TProtocol) (err error) { + if err := oprot.WriteFieldBegin(ctx, "ymin", thrift.DOUBLE, 3); err != nil { + return 
thrift.PrependError(fmt.Sprintf("%T write field begin error 3:ymin: ", p), err) + } + if err := oprot.WriteDouble(ctx, float64(p.Ymin)); err != nil { + return thrift.PrependError(fmt.Sprintf("%T.ymin (3) field write error: ", p), err) + } + if err := oprot.WriteFieldEnd(ctx); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field end error 3:ymin: ", p), err) + } + return err +} + +func (p *BoundingBox) writeField4(ctx context.Context, oprot thrift.TProtocol) (err error) { + if err := oprot.WriteFieldBegin(ctx, "ymax", thrift.DOUBLE, 4); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field begin error 4:ymax: ", p), err) + } + if err := oprot.WriteDouble(ctx, float64(p.Ymax)); err != nil { + return thrift.PrependError(fmt.Sprintf("%T.ymax (4) field write error: ", p), err) + } + if err := oprot.WriteFieldEnd(ctx); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field end error 4:ymax: ", p), err) + } + return err +} + +func (p *BoundingBox) writeField5(ctx context.Context, oprot thrift.TProtocol) (err error) { + if p.IsSetZmin() { + if err := oprot.WriteFieldBegin(ctx, "zmin", thrift.DOUBLE, 5); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field begin error 5:zmin: ", p), err) } - if err := oprot.WriteI64(ctx, int64(*p.NullCount)); err != nil { - return thrift.PrependError(fmt.Sprintf("%T.null_count (3) field write error: ", p), err) + if err := oprot.WriteDouble(ctx, float64(*p.Zmin)); err != nil { + return thrift.PrependError(fmt.Sprintf("%T.zmin (5) field write error: ", p), err) } if err := oprot.WriteFieldEnd(ctx); err != nil { - return thrift.PrependError(fmt.Sprintf("%T write field end error 3:null_count: ", p), err) + return thrift.PrependError(fmt.Sprintf("%T write field end error 5:zmin: ", p), err) } } return err } -func (p *Statistics) writeField4(ctx context.Context, oprot thrift.TProtocol) (err error) { - if p.IsSetDistinctCount() { - if err := oprot.WriteFieldBegin(ctx, "distinct_count", thrift.I64, 4); err != nil { - return thrift.PrependError(fmt.Sprintf("%T write field begin error 4:distinct_count: ", p), err) +func (p *BoundingBox) writeField6(ctx context.Context, oprot thrift.TProtocol) (err error) { + if p.IsSetZmax() { + if err := oprot.WriteFieldBegin(ctx, "zmax", thrift.DOUBLE, 6); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field begin error 6:zmax: ", p), err) } - if err := oprot.WriteI64(ctx, int64(*p.DistinctCount)); err != nil { - return thrift.PrependError(fmt.Sprintf("%T.distinct_count (4) field write error: ", p), err) + if err := oprot.WriteDouble(ctx, float64(*p.Zmax)); err != nil { + return thrift.PrependError(fmt.Sprintf("%T.zmax (6) field write error: ", p), err) } if err := oprot.WriteFieldEnd(ctx); err != nil { - return thrift.PrependError(fmt.Sprintf("%T write field end error 4:distinct_count: ", p), err) + return thrift.PrependError(fmt.Sprintf("%T write field end error 6:zmax: ", p), err) } } return err } -func (p *Statistics) writeField5(ctx context.Context, oprot thrift.TProtocol) (err error) { - if p.IsSetMaxValue() { - if err := oprot.WriteFieldBegin(ctx, "max_value", thrift.STRING, 5); err != nil { - return thrift.PrependError(fmt.Sprintf("%T write field begin error 5:max_value: ", p), err) +func (p *BoundingBox) writeField7(ctx context.Context, oprot thrift.TProtocol) (err error) { + if p.IsSetMmin() { + if err := oprot.WriteFieldBegin(ctx, "mmin", thrift.DOUBLE, 7); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field begin error 7:mmin: 
", p), err) } - if err := oprot.WriteBinary(ctx, p.MaxValue); err != nil { - return thrift.PrependError(fmt.Sprintf("%T.max_value (5) field write error: ", p), err) + if err := oprot.WriteDouble(ctx, float64(*p.Mmin)); err != nil { + return thrift.PrependError(fmt.Sprintf("%T.mmin (7) field write error: ", p), err) } if err := oprot.WriteFieldEnd(ctx); err != nil { - return thrift.PrependError(fmt.Sprintf("%T write field end error 5:max_value: ", p), err) + return thrift.PrependError(fmt.Sprintf("%T write field end error 7:mmin: ", p), err) } } return err } -func (p *Statistics) writeField6(ctx context.Context, oprot thrift.TProtocol) (err error) { - if p.IsSetMinValue() { - if err := oprot.WriteFieldBegin(ctx, "min_value", thrift.STRING, 6); err != nil { - return thrift.PrependError(fmt.Sprintf("%T write field begin error 6:min_value: ", p), err) +func (p *BoundingBox) writeField8(ctx context.Context, oprot thrift.TProtocol) (err error) { + if p.IsSetMmax() { + if err := oprot.WriteFieldBegin(ctx, "mmax", thrift.DOUBLE, 8); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field begin error 8:mmax: ", p), err) } - if err := oprot.WriteBinary(ctx, p.MinValue); err != nil { - return thrift.PrependError(fmt.Sprintf("%T.min_value (6) field write error: ", p), err) + if err := oprot.WriteDouble(ctx, float64(*p.Mmax)); err != nil { + return thrift.PrependError(fmt.Sprintf("%T.mmax (8) field write error: ", p), err) } if err := oprot.WriteFieldEnd(ctx); err != nil { - return thrift.PrependError(fmt.Sprintf("%T write field end error 6:min_value: ", p), err) + return thrift.PrependError(fmt.Sprintf("%T write field end error 8:mmax: ", p), err) } } return err } -func (p *Statistics) writeField7(ctx context.Context, oprot thrift.TProtocol) (err error) { - if p.IsSetIsMaxValueExact() { - if err := oprot.WriteFieldBegin(ctx, "is_max_value_exact", thrift.BOOL, 7); err != nil { - return thrift.PrependError(fmt.Sprintf("%T write field begin error 7:is_max_value_exact: ", p), err) - } - if err := oprot.WriteBool(ctx, bool(*p.IsMaxValueExact)); err != nil { - return thrift.PrependError(fmt.Sprintf("%T.is_max_value_exact (7) field write error: ", p), err) - } - if err := oprot.WriteFieldEnd(ctx); err != nil { - return thrift.PrependError(fmt.Sprintf("%T write field end error 7:is_max_value_exact: ", p), err) - } - } - return err -} - -func (p *Statistics) writeField8(ctx context.Context, oprot thrift.TProtocol) (err error) { - if p.IsSetIsMinValueExact() { - if err := oprot.WriteFieldBegin(ctx, "is_min_value_exact", thrift.BOOL, 8); err != nil { - return thrift.PrependError(fmt.Sprintf("%T write field begin error 8:is_min_value_exact: ", p), err) - } - if err := oprot.WriteBool(ctx, bool(*p.IsMinValueExact)); err != nil { - return thrift.PrependError(fmt.Sprintf("%T.is_min_value_exact (8) field write error: ", p), err) - } - if err := oprot.WriteFieldEnd(ctx); err != nil { - return thrift.PrependError(fmt.Sprintf("%T write field end error 8:is_min_value_exact: ", p), err) - } - } - return err -} - -func (p *Statistics) Equals(other *Statistics) bool { +func (p *BoundingBox) Equals(other *BoundingBox) bool { if p == other { return true } else if p == nil || other == nil { return false } - if bytes.Compare(p.Max, other.Max) != 0 { return false } - if bytes.Compare(p.Min, other.Min) != 0 { return false } - if p.NullCount != other.NullCount { - if p.NullCount == nil || other.NullCount == nil { + if p.Xmin != other.Xmin { return false } + if p.Xmax != other.Xmax { return false } + if p.Ymin != 
other.Ymin { return false } + if p.Ymax != other.Ymax { return false } + if p.Zmin != other.Zmin { + if p.Zmin == nil || other.Zmin == nil { return false } - if (*p.NullCount) != (*other.NullCount) { return false } + if (*p.Zmin) != (*other.Zmin) { return false } } - if p.DistinctCount != other.DistinctCount { - if p.DistinctCount == nil || other.DistinctCount == nil { + if p.Zmax != other.Zmax { + if p.Zmax == nil || other.Zmax == nil { return false } - if (*p.DistinctCount) != (*other.DistinctCount) { return false } + if (*p.Zmax) != (*other.Zmax) { return false } } - if bytes.Compare(p.MaxValue, other.MaxValue) != 0 { return false } - if bytes.Compare(p.MinValue, other.MinValue) != 0 { return false } - if p.IsMaxValueExact != other.IsMaxValueExact { - if p.IsMaxValueExact == nil || other.IsMaxValueExact == nil { + if p.Mmin != other.Mmin { + if p.Mmin == nil || other.Mmin == nil { return false } - if (*p.IsMaxValueExact) != (*other.IsMaxValueExact) { return false } + if (*p.Mmin) != (*other.Mmin) { return false } } - if p.IsMinValueExact != other.IsMinValueExact { - if p.IsMinValueExact == nil || other.IsMinValueExact == nil { + if p.Mmax != other.Mmax { + if p.Mmax == nil || other.Mmax == nil { return false } - if (*p.IsMinValueExact) != (*other.IsMinValueExact) { return false } + if (*p.Mmax) != (*other.Mmax) { return false } } return true } -func (p *Statistics) String() string { +func (p *BoundingBox) String() string { if p == nil { return "" } - return fmt.Sprintf("Statistics(%+v)", *p) + return fmt.Sprintf("BoundingBox(%+v)", *p) } -func (p *Statistics) LogValue() slog.Value { +func (p *BoundingBox) LogValue() slog.Value { if p == nil { return slog.AnyValue(nil) } v := thrift.SlogTStructWrapper{ - Type: "*parquet.Statistics", + Type: "*parquet.BoundingBox", Value: p, } return slog.AnyValue(v) } -var _ slog.LogValuer = (*Statistics)(nil) +var _ slog.LogValuer = (*BoundingBox)(nil) -func (p *Statistics) Validate() error { +func (p *BoundingBox) Validate() error { return nil } -// Empty structs to use as logical type annotations -type StringType struct { -} - -func NewStringType() *StringType { - return &StringType{} -} - -func (p *StringType) Read(ctx context.Context, iprot thrift.TProtocol) error { - if _, err := iprot.ReadStructBegin(ctx); err != nil { - return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) - } - - - for { - _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin(ctx) - if err != nil { - return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) - } - if fieldTypeId == thrift.STOP { - break - } - if err := iprot.Skip(ctx, fieldTypeId); err != nil { - return err - } - if err := iprot.ReadFieldEnd(ctx); err != nil { - return err - } - } - if err := iprot.ReadStructEnd(ctx); err != nil { - return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) - } - return nil +// Statistics specific to Geometry and Geography logical types +// +// Attributes: +// - Bbox: A bounding box of geospatial instances +// - GeospatialTypes: Geospatial type codes of all instances, or an empty list if not known +// +type GeospatialStatistics struct { + Bbox *BoundingBox `thrift:"bbox,1" db:"bbox" json:"bbox,omitempty"` + GeospatialTypes []int32 `thrift:"geospatial_types,2" db:"geospatial_types" json:"geospatial_types,omitempty"` } -func (p *StringType) Write(ctx context.Context, oprot thrift.TProtocol) error { - if err := oprot.WriteStructBegin(ctx, "StringType"); err != nil { - return thrift.PrependError(fmt.Sprintf("%T write struct 
begin error: ", p), err) - } - if p != nil { - } - if err := oprot.WriteFieldStop(ctx); err != nil { - return thrift.PrependError("write field stop error: ", err) - } - if err := oprot.WriteStructEnd(ctx); err != nil { - return thrift.PrependError("write struct stop error: ", err) - } - return nil +func NewGeospatialStatistics() *GeospatialStatistics { + return &GeospatialStatistics{} } -func (p *StringType) Equals(other *StringType) bool { - if p == other { - return true - } else if p == nil || other == nil { - return false - } - return true -} +var GeospatialStatistics_Bbox_DEFAULT *BoundingBox -func (p *StringType) String() string { - if p == nil { - return "" +func (p *GeospatialStatistics) GetBbox() *BoundingBox { + if !p.IsSetBbox() { + return GeospatialStatistics_Bbox_DEFAULT } - return fmt.Sprintf("StringType(%+v)", *p) + return p.Bbox } -func (p *StringType) LogValue() slog.Value { - if p == nil { - return slog.AnyValue(nil) - } - v := thrift.SlogTStructWrapper{ - Type: "*parquet.StringType", - Value: p, - } - return slog.AnyValue(v) -} +var GeospatialStatistics_GeospatialTypes_DEFAULT []int32 -var _ slog.LogValuer = (*StringType)(nil) -func (p *StringType) Validate() error { - return nil +func (p *GeospatialStatistics) GetGeospatialTypes() []int32 { + return p.GeospatialTypes } -type UUIDType struct { +func (p *GeospatialStatistics) IsSetBbox() bool { + return p.Bbox != nil } -func NewUUIDType() *UUIDType { - return &UUIDType{} +func (p *GeospatialStatistics) IsSetGeospatialTypes() bool { + return p.GeospatialTypes != nil } -func (p *UUIDType) Read(ctx context.Context, iprot thrift.TProtocol) error { +func (p *GeospatialStatistics) Read(ctx context.Context, iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(ctx); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } @@ -1516,8 +1497,31 @@ func (p *UUIDType) Read(ctx context.Context, iprot thrift.TProtocol) error { if fieldTypeId == thrift.STOP { break } - if err := iprot.Skip(ctx, fieldTypeId); err != nil { - return err + switch fieldId { + case 1: + if fieldTypeId == thrift.STRUCT { + if err := p.ReadField1(ctx, iprot); err != nil { + return err + } + } else { + if err := iprot.Skip(ctx, fieldTypeId); err != nil { + return err + } + } + case 2: + if fieldTypeId == thrift.LIST { + if err := p.ReadField2(ctx, iprot); err != nil { + return err + } + } else { + if err := iprot.Skip(ctx, fieldTypeId); err != nil { + return err + } + } + default: + if err := iprot.Skip(ctx, fieldTypeId); err != nil { + return err + } } if err := iprot.ReadFieldEnd(ctx); err != nil { return err @@ -1529,11 +1533,43 @@ func (p *UUIDType) Read(ctx context.Context, iprot thrift.TProtocol) error { return nil } -func (p *UUIDType) Write(ctx context.Context, oprot thrift.TProtocol) error { - if err := oprot.WriteStructBegin(ctx, "UUIDType"); err != nil { +func (p *GeospatialStatistics) ReadField1(ctx context.Context, iprot thrift.TProtocol) error { + p.Bbox = &BoundingBox{} + if err := p.Bbox.Read(ctx, iprot); err != nil { + return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", p.Bbox), err) + } + return nil +} + +func (p *GeospatialStatistics) ReadField2(ctx context.Context, iprot thrift.TProtocol) error { + _, size, err := iprot.ReadListBegin(ctx) + if err != nil { + return thrift.PrependError("error reading list begin: ", err) + } + tSlice := make([]int32, 0, size) + p.GeospatialTypes = tSlice + for i := 0; i < size; i++ { + var _elem4 int32 + if v, err := iprot.ReadI32(ctx); err != nil { + return 
thrift.PrependError("error reading field 0: ", err) + } else { + _elem4 = v + } + p.GeospatialTypes = append(p.GeospatialTypes, _elem4) + } + if err := iprot.ReadListEnd(ctx); err != nil { + return thrift.PrependError("error reading list end: ", err) + } + return nil +} + +func (p *GeospatialStatistics) Write(ctx context.Context, oprot thrift.TProtocol) error { + if err := oprot.WriteStructBegin(ctx, "GeospatialStatistics"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if p != nil { + if err := p.writeField1(ctx, oprot); err != nil { return err } + if err := p.writeField2(ctx, oprot); err != nil { return err } } if err := oprot.WriteFieldStop(ctx); err != nil { return thrift.PrependError("write field stop error: ", err) @@ -1544,43 +1580,786 @@ func (p *UUIDType) Write(ctx context.Context, oprot thrift.TProtocol) error { return nil } -func (p *UUIDType) Equals(other *UUIDType) bool { +func (p *GeospatialStatistics) writeField1(ctx context.Context, oprot thrift.TProtocol) (err error) { + if p.IsSetBbox() { + if err := oprot.WriteFieldBegin(ctx, "bbox", thrift.STRUCT, 1); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:bbox: ", p), err) + } + if err := p.Bbox.Write(ctx, oprot); err != nil { + return thrift.PrependError(fmt.Sprintf("%T error writing struct: ", p.Bbox), err) + } + if err := oprot.WriteFieldEnd(ctx); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field end error 1:bbox: ", p), err) + } + } + return err +} + +func (p *GeospatialStatistics) writeField2(ctx context.Context, oprot thrift.TProtocol) (err error) { + if p.IsSetGeospatialTypes() { + if err := oprot.WriteFieldBegin(ctx, "geospatial_types", thrift.LIST, 2); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field begin error 2:geospatial_types: ", p), err) + } + if err := oprot.WriteListBegin(ctx, thrift.I32, len(p.GeospatialTypes)); err != nil { + return thrift.PrependError("error writing list begin: ", err) + } + for _, v := range p.GeospatialTypes { + if err := oprot.WriteI32(ctx, int32(v)); err != nil { + return thrift.PrependError(fmt.Sprintf("%T. 
(0) field write error: ", p), err) + } + } + if err := oprot.WriteListEnd(ctx); err != nil { + return thrift.PrependError("error writing list end: ", err) + } + if err := oprot.WriteFieldEnd(ctx); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field end error 2:geospatial_types: ", p), err) + } + } + return err +} + +func (p *GeospatialStatistics) Equals(other *GeospatialStatistics) bool { if p == other { return true } else if p == nil || other == nil { return false } + if !p.Bbox.Equals(other.Bbox) { return false } + if len(p.GeospatialTypes) != len(other.GeospatialTypes) { return false } + for i, _tgt := range p.GeospatialTypes { + _src5 := other.GeospatialTypes[i] + if _tgt != _src5 { return false } + } return true } -func (p *UUIDType) String() string { +func (p *GeospatialStatistics) String() string { if p == nil { return "" } - return fmt.Sprintf("UUIDType(%+v)", *p) + return fmt.Sprintf("GeospatialStatistics(%+v)", *p) } -func (p *UUIDType) LogValue() slog.Value { +func (p *GeospatialStatistics) LogValue() slog.Value { if p == nil { return slog.AnyValue(nil) } v := thrift.SlogTStructWrapper{ - Type: "*parquet.UUIDType", + Type: "*parquet.GeospatialStatistics", Value: p, } return slog.AnyValue(v) } -var _ slog.LogValuer = (*UUIDType)(nil) +var _ slog.LogValuer = (*GeospatialStatistics)(nil) -func (p *UUIDType) Validate() error { +func (p *GeospatialStatistics) Validate() error { return nil } -type MapType struct { +// Statistics per row group and per page +// All fields are optional. +// +// Attributes: +// - Max: DEPRECATED: min and max value of the column. Use min_value and max_value. +// +// Values are encoded using PLAIN encoding, except that variable-length byte +// arrays do not include a length prefix. +// +// These fields encode min and max values determined by signed comparison +// only. New files should use the correct order for a column's logical type +// and store the values in the min_value and max_value fields. +// +// To support older readers, these may be set when the column order is +// signed. +// - Min +// - NullCount: Count of null values in the column. +// +// Writers SHOULD always write this field even if it is zero (i.e. no null value) +// or the column is not nullable. +// Readers MUST distinguish between null_count not being present and null_count == 0. +// If null_count is not present, readers MUST NOT assume null_count == 0. +// - DistinctCount: count of distinct values occurring +// - MaxValue: Lower and upper bound values for the column, determined by its ColumnOrder. +// +// These may be the actual minimum and maximum values found on a page or column +// chunk, but can also be (more compact) values that do not exist on a page or +// column chunk. For example, instead of storing "Blart Versenwald III", a writer +// may set min_value="B", max_value="C". Such more compact values must still be +// valid values within the column's logical type. +// +// Values are encoded using PLAIN encoding, except that variable-length byte +// arrays do not include a length prefix. 
+// - MinValue +// - IsMaxValueExact: If true, max_value is the actual maximum value for a column +// - IsMinValueExact: If true, min_value is the actual minimum value for a column +// +type Statistics struct { + Max []byte `thrift:"max,1" db:"max" json:"max,omitempty"` + Min []byte `thrift:"min,2" db:"min" json:"min,omitempty"` + NullCount *int64 `thrift:"null_count,3" db:"null_count" json:"null_count,omitempty"` + DistinctCount *int64 `thrift:"distinct_count,4" db:"distinct_count" json:"distinct_count,omitempty"` + MaxValue []byte `thrift:"max_value,5" db:"max_value" json:"max_value,omitempty"` + MinValue []byte `thrift:"min_value,6" db:"min_value" json:"min_value,omitempty"` + IsMaxValueExact *bool `thrift:"is_max_value_exact,7" db:"is_max_value_exact" json:"is_max_value_exact,omitempty"` + IsMinValueExact *bool `thrift:"is_min_value_exact,8" db:"is_min_value_exact" json:"is_min_value_exact,omitempty"` } -func NewMapType() *MapType { +func NewStatistics() *Statistics { + return &Statistics{} +} + +var Statistics_Max_DEFAULT []byte + + +func (p *Statistics) GetMax() []byte { + return p.Max +} + +var Statistics_Min_DEFAULT []byte + + +func (p *Statistics) GetMin() []byte { + return p.Min +} + +var Statistics_NullCount_DEFAULT int64 + +func (p *Statistics) GetNullCount() int64 { + if !p.IsSetNullCount() { + return Statistics_NullCount_DEFAULT + } + return *p.NullCount +} + +var Statistics_DistinctCount_DEFAULT int64 + +func (p *Statistics) GetDistinctCount() int64 { + if !p.IsSetDistinctCount() { + return Statistics_DistinctCount_DEFAULT + } + return *p.DistinctCount +} + +var Statistics_MaxValue_DEFAULT []byte + + +func (p *Statistics) GetMaxValue() []byte { + return p.MaxValue +} + +var Statistics_MinValue_DEFAULT []byte + + +func (p *Statistics) GetMinValue() []byte { + return p.MinValue +} + +var Statistics_IsMaxValueExact_DEFAULT bool + +func (p *Statistics) GetIsMaxValueExact() bool { + if !p.IsSetIsMaxValueExact() { + return Statistics_IsMaxValueExact_DEFAULT + } + return *p.IsMaxValueExact +} + +var Statistics_IsMinValueExact_DEFAULT bool + +func (p *Statistics) GetIsMinValueExact() bool { + if !p.IsSetIsMinValueExact() { + return Statistics_IsMinValueExact_DEFAULT + } + return *p.IsMinValueExact +} + +func (p *Statistics) IsSetMax() bool { + return p.Max != nil +} + +func (p *Statistics) IsSetMin() bool { + return p.Min != nil +} + +func (p *Statistics) IsSetNullCount() bool { + return p.NullCount != nil +} + +func (p *Statistics) IsSetDistinctCount() bool { + return p.DistinctCount != nil +} + +func (p *Statistics) IsSetMaxValue() bool { + return p.MaxValue != nil +} + +func (p *Statistics) IsSetMinValue() bool { + return p.MinValue != nil +} + +func (p *Statistics) IsSetIsMaxValueExact() bool { + return p.IsMaxValueExact != nil +} + +func (p *Statistics) IsSetIsMinValueExact() bool { + return p.IsMinValueExact != nil +} + +func (p *Statistics) Read(ctx context.Context, iprot thrift.TProtocol) error { + if _, err := iprot.ReadStructBegin(ctx); err != nil { + return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) + } + + + for { + _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin(ctx) + if err != nil { + return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) + } + if fieldTypeId == thrift.STOP { + break + } + switch fieldId { + case 1: + if fieldTypeId == thrift.STRING { + if err := p.ReadField1(ctx, iprot); err != nil { + return err + } + } else { + if err := iprot.Skip(ctx, fieldTypeId); err != nil { + return err + } + } + case 2: + 
if fieldTypeId == thrift.STRING { + if err := p.ReadField2(ctx, iprot); err != nil { + return err + } + } else { + if err := iprot.Skip(ctx, fieldTypeId); err != nil { + return err + } + } + case 3: + if fieldTypeId == thrift.I64 { + if err := p.ReadField3(ctx, iprot); err != nil { + return err + } + } else { + if err := iprot.Skip(ctx, fieldTypeId); err != nil { + return err + } + } + case 4: + if fieldTypeId == thrift.I64 { + if err := p.ReadField4(ctx, iprot); err != nil { + return err + } + } else { + if err := iprot.Skip(ctx, fieldTypeId); err != nil { + return err + } + } + case 5: + if fieldTypeId == thrift.STRING { + if err := p.ReadField5(ctx, iprot); err != nil { + return err + } + } else { + if err := iprot.Skip(ctx, fieldTypeId); err != nil { + return err + } + } + case 6: + if fieldTypeId == thrift.STRING { + if err := p.ReadField6(ctx, iprot); err != nil { + return err + } + } else { + if err := iprot.Skip(ctx, fieldTypeId); err != nil { + return err + } + } + case 7: + if fieldTypeId == thrift.BOOL { + if err := p.ReadField7(ctx, iprot); err != nil { + return err + } + } else { + if err := iprot.Skip(ctx, fieldTypeId); err != nil { + return err + } + } + case 8: + if fieldTypeId == thrift.BOOL { + if err := p.ReadField8(ctx, iprot); err != nil { + return err + } + } else { + if err := iprot.Skip(ctx, fieldTypeId); err != nil { + return err + } + } + default: + if err := iprot.Skip(ctx, fieldTypeId); err != nil { + return err + } + } + if err := iprot.ReadFieldEnd(ctx); err != nil { + return err + } + } + if err := iprot.ReadStructEnd(ctx); err != nil { + return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) + } + return nil +} + +func (p *Statistics) ReadField1(ctx context.Context, iprot thrift.TProtocol) error { + if v, err := iprot.ReadBinary(ctx); err != nil { + return thrift.PrependError("error reading field 1: ", err) + } else { + p.Max = v + } + return nil +} + +func (p *Statistics) ReadField2(ctx context.Context, iprot thrift.TProtocol) error { + if v, err := iprot.ReadBinary(ctx); err != nil { + return thrift.PrependError("error reading field 2: ", err) + } else { + p.Min = v + } + return nil +} + +func (p *Statistics) ReadField3(ctx context.Context, iprot thrift.TProtocol) error { + if v, err := iprot.ReadI64(ctx); err != nil { + return thrift.PrependError("error reading field 3: ", err) + } else { + p.NullCount = &v + } + return nil +} + +func (p *Statistics) ReadField4(ctx context.Context, iprot thrift.TProtocol) error { + if v, err := iprot.ReadI64(ctx); err != nil { + return thrift.PrependError("error reading field 4: ", err) + } else { + p.DistinctCount = &v + } + return nil +} + +func (p *Statistics) ReadField5(ctx context.Context, iprot thrift.TProtocol) error { + if v, err := iprot.ReadBinary(ctx); err != nil { + return thrift.PrependError("error reading field 5: ", err) + } else { + p.MaxValue = v + } + return nil +} + +func (p *Statistics) ReadField6(ctx context.Context, iprot thrift.TProtocol) error { + if v, err := iprot.ReadBinary(ctx); err != nil { + return thrift.PrependError("error reading field 6: ", err) + } else { + p.MinValue = v + } + return nil +} + +func (p *Statistics) ReadField7(ctx context.Context, iprot thrift.TProtocol) error { + if v, err := iprot.ReadBool(ctx); err != nil { + return thrift.PrependError("error reading field 7: ", err) + } else { + p.IsMaxValueExact = &v + } + return nil +} + +func (p *Statistics) ReadField8(ctx context.Context, iprot thrift.TProtocol) error { + if v, err := iprot.ReadBool(ctx); 
err != nil { + return thrift.PrependError("error reading field 8: ", err) + } else { + p.IsMinValueExact = &v + } + return nil +} + +func (p *Statistics) Write(ctx context.Context, oprot thrift.TProtocol) error { + if err := oprot.WriteStructBegin(ctx, "Statistics"); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) + } + if p != nil { + if err := p.writeField1(ctx, oprot); err != nil { return err } + if err := p.writeField2(ctx, oprot); err != nil { return err } + if err := p.writeField3(ctx, oprot); err != nil { return err } + if err := p.writeField4(ctx, oprot); err != nil { return err } + if err := p.writeField5(ctx, oprot); err != nil { return err } + if err := p.writeField6(ctx, oprot); err != nil { return err } + if err := p.writeField7(ctx, oprot); err != nil { return err } + if err := p.writeField8(ctx, oprot); err != nil { return err } + } + if err := oprot.WriteFieldStop(ctx); err != nil { + return thrift.PrependError("write field stop error: ", err) + } + if err := oprot.WriteStructEnd(ctx); err != nil { + return thrift.PrependError("write struct stop error: ", err) + } + return nil +} + +func (p *Statistics) writeField1(ctx context.Context, oprot thrift.TProtocol) (err error) { + if p.IsSetMax() { + if err := oprot.WriteFieldBegin(ctx, "max", thrift.STRING, 1); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:max: ", p), err) + } + if err := oprot.WriteBinary(ctx, p.Max); err != nil { + return thrift.PrependError(fmt.Sprintf("%T.max (1) field write error: ", p), err) + } + if err := oprot.WriteFieldEnd(ctx); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field end error 1:max: ", p), err) + } + } + return err +} + +func (p *Statistics) writeField2(ctx context.Context, oprot thrift.TProtocol) (err error) { + if p.IsSetMin() { + if err := oprot.WriteFieldBegin(ctx, "min", thrift.STRING, 2); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field begin error 2:min: ", p), err) + } + if err := oprot.WriteBinary(ctx, p.Min); err != nil { + return thrift.PrependError(fmt.Sprintf("%T.min (2) field write error: ", p), err) + } + if err := oprot.WriteFieldEnd(ctx); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field end error 2:min: ", p), err) + } + } + return err +} + +func (p *Statistics) writeField3(ctx context.Context, oprot thrift.TProtocol) (err error) { + if p.IsSetNullCount() { + if err := oprot.WriteFieldBegin(ctx, "null_count", thrift.I64, 3); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field begin error 3:null_count: ", p), err) + } + if err := oprot.WriteI64(ctx, int64(*p.NullCount)); err != nil { + return thrift.PrependError(fmt.Sprintf("%T.null_count (3) field write error: ", p), err) + } + if err := oprot.WriteFieldEnd(ctx); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field end error 3:null_count: ", p), err) + } + } + return err +} + +func (p *Statistics) writeField4(ctx context.Context, oprot thrift.TProtocol) (err error) { + if p.IsSetDistinctCount() { + if err := oprot.WriteFieldBegin(ctx, "distinct_count", thrift.I64, 4); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field begin error 4:distinct_count: ", p), err) + } + if err := oprot.WriteI64(ctx, int64(*p.DistinctCount)); err != nil { + return thrift.PrependError(fmt.Sprintf("%T.distinct_count (4) field write error: ", p), err) + } + if err := oprot.WriteFieldEnd(ctx); err != nil { + return thrift.PrependError(fmt.Sprintf("%T 
write field end error 4:distinct_count: ", p), err) + } + } + return err +} + +func (p *Statistics) writeField5(ctx context.Context, oprot thrift.TProtocol) (err error) { + if p.IsSetMaxValue() { + if err := oprot.WriteFieldBegin(ctx, "max_value", thrift.STRING, 5); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field begin error 5:max_value: ", p), err) + } + if err := oprot.WriteBinary(ctx, p.MaxValue); err != nil { + return thrift.PrependError(fmt.Sprintf("%T.max_value (5) field write error: ", p), err) + } + if err := oprot.WriteFieldEnd(ctx); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field end error 5:max_value: ", p), err) + } + } + return err +} + +func (p *Statistics) writeField6(ctx context.Context, oprot thrift.TProtocol) (err error) { + if p.IsSetMinValue() { + if err := oprot.WriteFieldBegin(ctx, "min_value", thrift.STRING, 6); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field begin error 6:min_value: ", p), err) + } + if err := oprot.WriteBinary(ctx, p.MinValue); err != nil { + return thrift.PrependError(fmt.Sprintf("%T.min_value (6) field write error: ", p), err) + } + if err := oprot.WriteFieldEnd(ctx); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field end error 6:min_value: ", p), err) + } + } + return err +} + +func (p *Statistics) writeField7(ctx context.Context, oprot thrift.TProtocol) (err error) { + if p.IsSetIsMaxValueExact() { + if err := oprot.WriteFieldBegin(ctx, "is_max_value_exact", thrift.BOOL, 7); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field begin error 7:is_max_value_exact: ", p), err) + } + if err := oprot.WriteBool(ctx, bool(*p.IsMaxValueExact)); err != nil { + return thrift.PrependError(fmt.Sprintf("%T.is_max_value_exact (7) field write error: ", p), err) + } + if err := oprot.WriteFieldEnd(ctx); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field end error 7:is_max_value_exact: ", p), err) + } + } + return err +} + +func (p *Statistics) writeField8(ctx context.Context, oprot thrift.TProtocol) (err error) { + if p.IsSetIsMinValueExact() { + if err := oprot.WriteFieldBegin(ctx, "is_min_value_exact", thrift.BOOL, 8); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field begin error 8:is_min_value_exact: ", p), err) + } + if err := oprot.WriteBool(ctx, bool(*p.IsMinValueExact)); err != nil { + return thrift.PrependError(fmt.Sprintf("%T.is_min_value_exact (8) field write error: ", p), err) + } + if err := oprot.WriteFieldEnd(ctx); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field end error 8:is_min_value_exact: ", p), err) + } + } + return err +} + +func (p *Statistics) Equals(other *Statistics) bool { + if p == other { + return true + } else if p == nil || other == nil { + return false + } + if bytes.Compare(p.Max, other.Max) != 0 { return false } + if bytes.Compare(p.Min, other.Min) != 0 { return false } + if p.NullCount != other.NullCount { + if p.NullCount == nil || other.NullCount == nil { + return false + } + if (*p.NullCount) != (*other.NullCount) { return false } + } + if p.DistinctCount != other.DistinctCount { + if p.DistinctCount == nil || other.DistinctCount == nil { + return false + } + if (*p.DistinctCount) != (*other.DistinctCount) { return false } + } + if bytes.Compare(p.MaxValue, other.MaxValue) != 0 { return false } + if bytes.Compare(p.MinValue, other.MinValue) != 0 { return false } + if p.IsMaxValueExact != other.IsMaxValueExact { + if p.IsMaxValueExact == nil || other.IsMaxValueExact 
== nil { + return false + } + if (*p.IsMaxValueExact) != (*other.IsMaxValueExact) { return false } + } + if p.IsMinValueExact != other.IsMinValueExact { + if p.IsMinValueExact == nil || other.IsMinValueExact == nil { + return false + } + if (*p.IsMinValueExact) != (*other.IsMinValueExact) { return false } + } + return true +} + +func (p *Statistics) String() string { + if p == nil { + return "" + } + return fmt.Sprintf("Statistics(%+v)", *p) +} + +func (p *Statistics) LogValue() slog.Value { + if p == nil { + return slog.AnyValue(nil) + } + v := thrift.SlogTStructWrapper{ + Type: "*parquet.Statistics", + Value: p, + } + return slog.AnyValue(v) +} + +var _ slog.LogValuer = (*Statistics)(nil) + +func (p *Statistics) Validate() error { + return nil +} + +// Empty structs to use as logical type annotations +type StringType struct { +} + +func NewStringType() *StringType { + return &StringType{} +} + +func (p *StringType) Read(ctx context.Context, iprot thrift.TProtocol) error { + if _, err := iprot.ReadStructBegin(ctx); err != nil { + return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) + } + + + for { + _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin(ctx) + if err != nil { + return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) + } + if fieldTypeId == thrift.STOP { + break + } + if err := iprot.Skip(ctx, fieldTypeId); err != nil { + return err + } + if err := iprot.ReadFieldEnd(ctx); err != nil { + return err + } + } + if err := iprot.ReadStructEnd(ctx); err != nil { + return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) + } + return nil +} + +func (p *StringType) Write(ctx context.Context, oprot thrift.TProtocol) error { + if err := oprot.WriteStructBegin(ctx, "StringType"); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) + } + if p != nil { + } + if err := oprot.WriteFieldStop(ctx); err != nil { + return thrift.PrependError("write field stop error: ", err) + } + if err := oprot.WriteStructEnd(ctx); err != nil { + return thrift.PrependError("write struct stop error: ", err) + } + return nil +} + +func (p *StringType) Equals(other *StringType) bool { + if p == other { + return true + } else if p == nil || other == nil { + return false + } + return true +} + +func (p *StringType) String() string { + if p == nil { + return "" + } + return fmt.Sprintf("StringType(%+v)", *p) +} + +func (p *StringType) LogValue() slog.Value { + if p == nil { + return slog.AnyValue(nil) + } + v := thrift.SlogTStructWrapper{ + Type: "*parquet.StringType", + Value: p, + } + return slog.AnyValue(v) +} + +var _ slog.LogValuer = (*StringType)(nil) + +func (p *StringType) Validate() error { + return nil +} + +type UUIDType struct { +} + +func NewUUIDType() *UUIDType { + return &UUIDType{} +} + +func (p *UUIDType) Read(ctx context.Context, iprot thrift.TProtocol) error { + if _, err := iprot.ReadStructBegin(ctx); err != nil { + return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) + } + + + for { + _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin(ctx) + if err != nil { + return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) + } + if fieldTypeId == thrift.STOP { + break + } + if err := iprot.Skip(ctx, fieldTypeId); err != nil { + return err + } + if err := iprot.ReadFieldEnd(ctx); err != nil { + return err + } + } + if err := iprot.ReadStructEnd(ctx); err != nil { + return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) + } + return 
nil +} + +func (p *UUIDType) Write(ctx context.Context, oprot thrift.TProtocol) error { + if err := oprot.WriteStructBegin(ctx, "UUIDType"); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) + } + if p != nil { + } + if err := oprot.WriteFieldStop(ctx); err != nil { + return thrift.PrependError("write field stop error: ", err) + } + if err := oprot.WriteStructEnd(ctx); err != nil { + return thrift.PrependError("write struct stop error: ", err) + } + return nil +} + +func (p *UUIDType) Equals(other *UUIDType) bool { + if p == other { + return true + } else if p == nil || other == nil { + return false + } + return true +} + +func (p *UUIDType) String() string { + if p == nil { + return "" + } + return fmt.Sprintf("UUIDType(%+v)", *p) +} + +func (p *UUIDType) LogValue() slog.Value { + if p == nil { + return slog.AnyValue(nil) + } + v := thrift.SlogTStructWrapper{ + Type: "*parquet.UUIDType", + Value: p, + } + return slog.AnyValue(v) +} + +var _ slog.LogValuer = (*UUIDType)(nil) + +func (p *UUIDType) Validate() error { + return nil +} + +type MapType struct { +} + +func NewMapType() *MapType { return &MapType{} } @@ -2315,7 +3094,171 @@ func (p *MilliSeconds) Write(ctx context.Context, oprot thrift.TProtocol) error return nil } -func (p *MilliSeconds) Equals(other *MilliSeconds) bool { +func (p *MilliSeconds) Equals(other *MilliSeconds) bool { + if p == other { + return true + } else if p == nil || other == nil { + return false + } + return true +} + +func (p *MilliSeconds) String() string { + if p == nil { + return "" + } + return fmt.Sprintf("MilliSeconds(%+v)", *p) +} + +func (p *MilliSeconds) LogValue() slog.Value { + if p == nil { + return slog.AnyValue(nil) + } + v := thrift.SlogTStructWrapper{ + Type: "*parquet.MilliSeconds", + Value: p, + } + return slog.AnyValue(v) +} + +var _ slog.LogValuer = (*MilliSeconds)(nil) + +func (p *MilliSeconds) Validate() error { + return nil +} + +type MicroSeconds struct { +} + +func NewMicroSeconds() *MicroSeconds { + return &MicroSeconds{} +} + +func (p *MicroSeconds) Read(ctx context.Context, iprot thrift.TProtocol) error { + if _, err := iprot.ReadStructBegin(ctx); err != nil { + return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) + } + + + for { + _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin(ctx) + if err != nil { + return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) + } + if fieldTypeId == thrift.STOP { + break + } + if err := iprot.Skip(ctx, fieldTypeId); err != nil { + return err + } + if err := iprot.ReadFieldEnd(ctx); err != nil { + return err + } + } + if err := iprot.ReadStructEnd(ctx); err != nil { + return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) + } + return nil +} + +func (p *MicroSeconds) Write(ctx context.Context, oprot thrift.TProtocol) error { + if err := oprot.WriteStructBegin(ctx, "MicroSeconds"); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) + } + if p != nil { + } + if err := oprot.WriteFieldStop(ctx); err != nil { + return thrift.PrependError("write field stop error: ", err) + } + if err := oprot.WriteStructEnd(ctx); err != nil { + return thrift.PrependError("write struct stop error: ", err) + } + return nil +} + +func (p *MicroSeconds) Equals(other *MicroSeconds) bool { + if p == other { + return true + } else if p == nil || other == nil { + return false + } + return true +} + +func (p *MicroSeconds) String() string { + if p == nil { + return "" + } 
+ return fmt.Sprintf("MicroSeconds(%+v)", *p) +} + +func (p *MicroSeconds) LogValue() slog.Value { + if p == nil { + return slog.AnyValue(nil) + } + v := thrift.SlogTStructWrapper{ + Type: "*parquet.MicroSeconds", + Value: p, + } + return slog.AnyValue(v) +} + +var _ slog.LogValuer = (*MicroSeconds)(nil) + +func (p *MicroSeconds) Validate() error { + return nil +} + +type NanoSeconds struct { +} + +func NewNanoSeconds() *NanoSeconds { + return &NanoSeconds{} +} + +func (p *NanoSeconds) Read(ctx context.Context, iprot thrift.TProtocol) error { + if _, err := iprot.ReadStructBegin(ctx); err != nil { + return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) + } + + + for { + _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin(ctx) + if err != nil { + return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) + } + if fieldTypeId == thrift.STOP { + break + } + if err := iprot.Skip(ctx, fieldTypeId); err != nil { + return err + } + if err := iprot.ReadFieldEnd(ctx); err != nil { + return err + } + } + if err := iprot.ReadStructEnd(ctx); err != nil { + return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) + } + return nil +} + +func (p *NanoSeconds) Write(ctx context.Context, oprot thrift.TProtocol) error { + if err := oprot.WriteStructBegin(ctx, "NanoSeconds"); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) + } + if p != nil { + } + if err := oprot.WriteFieldStop(ctx); err != nil { + return thrift.PrependError("write field stop error: ", err) + } + if err := oprot.WriteStructEnd(ctx); err != nil { + return thrift.PrependError("write struct stop error: ", err) + } + return nil +} + +func (p *NanoSeconds) Equals(other *NanoSeconds) bool { if p == other { return true } else if p == nil || other == nil { @@ -2324,38 +3267,100 @@ func (p *MilliSeconds) Equals(other *MilliSeconds) bool { return true } -func (p *MilliSeconds) String() string { +func (p *NanoSeconds) String() string { if p == nil { return "" } - return fmt.Sprintf("MilliSeconds(%+v)", *p) + return fmt.Sprintf("NanoSeconds(%+v)", *p) } -func (p *MilliSeconds) LogValue() slog.Value { +func (p *NanoSeconds) LogValue() slog.Value { if p == nil { return slog.AnyValue(nil) } v := thrift.SlogTStructWrapper{ - Type: "*parquet.MilliSeconds", + Type: "*parquet.NanoSeconds", Value: p, } return slog.AnyValue(v) } -var _ slog.LogValuer = (*MilliSeconds)(nil) +var _ slog.LogValuer = (*NanoSeconds)(nil) -func (p *MilliSeconds) Validate() error { +func (p *NanoSeconds) Validate() error { return nil } -type MicroSeconds struct { +// Attributes: +// - MILLIS +// - MICROS +// - NANOS +// +type TimeUnit struct { + MILLIS *MilliSeconds `thrift:"MILLIS,1" db:"MILLIS" json:"MILLIS,omitempty"` + MICROS *MicroSeconds `thrift:"MICROS,2" db:"MICROS" json:"MICROS,omitempty"` + NANOS *NanoSeconds `thrift:"NANOS,3" db:"NANOS" json:"NANOS,omitempty"` } -func NewMicroSeconds() *MicroSeconds { - return &MicroSeconds{} +func NewTimeUnit() *TimeUnit { + return &TimeUnit{} } -func (p *MicroSeconds) Read(ctx context.Context, iprot thrift.TProtocol) error { +var TimeUnit_MILLIS_DEFAULT *MilliSeconds + +func (p *TimeUnit) GetMILLIS() *MilliSeconds { + if !p.IsSetMILLIS() { + return TimeUnit_MILLIS_DEFAULT + } + return p.MILLIS +} + +var TimeUnit_MICROS_DEFAULT *MicroSeconds + +func (p *TimeUnit) GetMICROS() *MicroSeconds { + if !p.IsSetMICROS() { + return TimeUnit_MICROS_DEFAULT + } + return p.MICROS +} + +var TimeUnit_NANOS_DEFAULT *NanoSeconds + +func (p 
*TimeUnit) GetNANOS() *NanoSeconds { + if !p.IsSetNANOS() { + return TimeUnit_NANOS_DEFAULT + } + return p.NANOS +} + +func (p *TimeUnit) CountSetFieldsTimeUnit() int { + count := 0 + if (p.IsSetMILLIS()) { + count++ + } + if (p.IsSetMICROS()) { + count++ + } + if (p.IsSetNANOS()) { + count++ + } + return count + +} + +func (p *TimeUnit) IsSetMILLIS() bool { + return p.MILLIS != nil +} + +func (p *TimeUnit) IsSetMICROS() bool { + return p.MICROS != nil +} + +func (p *TimeUnit) IsSetNANOS() bool { + return p.NANOS != nil +} + +func (p *TimeUnit) Read(ctx context.Context, iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(ctx); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } @@ -2369,8 +3374,41 @@ func (p *MicroSeconds) Read(ctx context.Context, iprot thrift.TProtocol) error { if fieldTypeId == thrift.STOP { break } - if err := iprot.Skip(ctx, fieldTypeId); err != nil { - return err + switch fieldId { + case 1: + if fieldTypeId == thrift.STRUCT { + if err := p.ReadField1(ctx, iprot); err != nil { + return err + } + } else { + if err := iprot.Skip(ctx, fieldTypeId); err != nil { + return err + } + } + case 2: + if fieldTypeId == thrift.STRUCT { + if err := p.ReadField2(ctx, iprot); err != nil { + return err + } + } else { + if err := iprot.Skip(ctx, fieldTypeId); err != nil { + return err + } + } + case 3: + if fieldTypeId == thrift.STRUCT { + if err := p.ReadField3(ctx, iprot); err != nil { + return err + } + } else { + if err := iprot.Skip(ctx, fieldTypeId); err != nil { + return err + } + } + default: + if err := iprot.Skip(ctx, fieldTypeId); err != nil { + return err + } } if err := iprot.ReadFieldEnd(ctx); err != nil { return err @@ -2382,11 +3420,41 @@ func (p *MicroSeconds) Read(ctx context.Context, iprot thrift.TProtocol) error { return nil } -func (p *MicroSeconds) Write(ctx context.Context, oprot thrift.TProtocol) error { - if err := oprot.WriteStructBegin(ctx, "MicroSeconds"); err != nil { +func (p *TimeUnit) ReadField1(ctx context.Context, iprot thrift.TProtocol) error { + p.MILLIS = &MilliSeconds{} + if err := p.MILLIS.Read(ctx, iprot); err != nil { + return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", p.MILLIS), err) + } + return nil +} + +func (p *TimeUnit) ReadField2(ctx context.Context, iprot thrift.TProtocol) error { + p.MICROS = &MicroSeconds{} + if err := p.MICROS.Read(ctx, iprot); err != nil { + return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", p.MICROS), err) + } + return nil +} + +func (p *TimeUnit) ReadField3(ctx context.Context, iprot thrift.TProtocol) error { + p.NANOS = &NanoSeconds{} + if err := p.NANOS.Read(ctx, iprot); err != nil { + return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", p.NANOS), err) + } + return nil +} + +func (p *TimeUnit) Write(ctx context.Context, oprot thrift.TProtocol) error { + if c := p.CountSetFieldsTimeUnit(); c != 1 { + return fmt.Errorf("%T write union: exactly one field must be set (%d set)", p, c) + } + if err := oprot.WriteStructBegin(ctx, "TimeUnit"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if p != nil { + if err := p.writeField1(ctx, oprot); err != nil { return err } + if err := p.writeField2(ctx, oprot); err != nil { return err } + if err := p.writeField3(ctx, oprot); err != nil { return err } } if err := oprot.WriteFieldStop(ctx); err != nil { return thrift.PrependError("write field stop error: ", err) @@ -2397,51 +3465,130 @@ func (p *MicroSeconds) Write(ctx context.Context, 
oprot thrift.TProtocol) error return nil } -func (p *MicroSeconds) Equals(other *MicroSeconds) bool { +func (p *TimeUnit) writeField1(ctx context.Context, oprot thrift.TProtocol) (err error) { + if p.IsSetMILLIS() { + if err := oprot.WriteFieldBegin(ctx, "MILLIS", thrift.STRUCT, 1); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:MILLIS: ", p), err) + } + if err := p.MILLIS.Write(ctx, oprot); err != nil { + return thrift.PrependError(fmt.Sprintf("%T error writing struct: ", p.MILLIS), err) + } + if err := oprot.WriteFieldEnd(ctx); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field end error 1:MILLIS: ", p), err) + } + } + return err +} + +func (p *TimeUnit) writeField2(ctx context.Context, oprot thrift.TProtocol) (err error) { + if p.IsSetMICROS() { + if err := oprot.WriteFieldBegin(ctx, "MICROS", thrift.STRUCT, 2); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field begin error 2:MICROS: ", p), err) + } + if err := p.MICROS.Write(ctx, oprot); err != nil { + return thrift.PrependError(fmt.Sprintf("%T error writing struct: ", p.MICROS), err) + } + if err := oprot.WriteFieldEnd(ctx); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field end error 2:MICROS: ", p), err) + } + } + return err +} + +func (p *TimeUnit) writeField3(ctx context.Context, oprot thrift.TProtocol) (err error) { + if p.IsSetNANOS() { + if err := oprot.WriteFieldBegin(ctx, "NANOS", thrift.STRUCT, 3); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field begin error 3:NANOS: ", p), err) + } + if err := p.NANOS.Write(ctx, oprot); err != nil { + return thrift.PrependError(fmt.Sprintf("%T error writing struct: ", p.NANOS), err) + } + if err := oprot.WriteFieldEnd(ctx); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field end error 3:NANOS: ", p), err) + } + } + return err +} + +func (p *TimeUnit) Equals(other *TimeUnit) bool { if p == other { return true } else if p == nil || other == nil { return false } + if !p.MILLIS.Equals(other.MILLIS) { return false } + if !p.MICROS.Equals(other.MICROS) { return false } + if !p.NANOS.Equals(other.NANOS) { return false } return true } -func (p *MicroSeconds) String() string { +func (p *TimeUnit) String() string { if p == nil { return "" } - return fmt.Sprintf("MicroSeconds(%+v)", *p) + return fmt.Sprintf("TimeUnit(%+v)", *p) } -func (p *MicroSeconds) LogValue() slog.Value { +func (p *TimeUnit) LogValue() slog.Value { if p == nil { return slog.AnyValue(nil) } v := thrift.SlogTStructWrapper{ - Type: "*parquet.MicroSeconds", + Type: "*parquet.TimeUnit", Value: p, } return slog.AnyValue(v) } -var _ slog.LogValuer = (*MicroSeconds)(nil) +var _ slog.LogValuer = (*TimeUnit)(nil) -func (p *MicroSeconds) Validate() error { +func (p *TimeUnit) Validate() error { return nil } -type NanoSeconds struct { +// Timestamp logical type annotation +// +// Allowed for physical types: INT64 +// +// Attributes: +// - IsAdjustedToUTC +// - Unit +// +type TimestampType struct { + IsAdjustedToUTC bool `thrift:"isAdjustedToUTC,1,required" db:"isAdjustedToUTC" json:"isAdjustedToUTC"` + Unit *TimeUnit `thrift:"unit,2,required" db:"unit" json:"unit"` } -func NewNanoSeconds() *NanoSeconds { - return &NanoSeconds{} +func NewTimestampType() *TimestampType { + return &TimestampType{} } -func (p *NanoSeconds) Read(ctx context.Context, iprot thrift.TProtocol) error { + + +func (p *TimestampType) GetIsAdjustedToUTC() bool { + return p.IsAdjustedToUTC +} + +var TimestampType_Unit_DEFAULT *TimeUnit + 
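To make the shape of the regenerated structs concrete, here is a minimal sketch that builds a GeospatialStatistics with its required BoundingBox coordinates and a TIMESTAMP annotation whose TimeUnit union has exactly one member set (the constraint CountSetFieldsTimeUnit enforces in TimeUnit.Write). The import path is an assumption for wherever the generated parquet package lives in this module; everything else uses only constructors, fields, and methods that appear in the generated code above.

// Illustrative sketch only; not part of the generated code.
package main

import (
	"fmt"

	// Assumed location of the generated package; adjust to the real path in this module.
	"github.com/apache/arrow-go/v18/parquet/internal/gen-go/parquet"
)

func main() {
	// BoundingBox: xmin/xmax/ymin/ymax are required doubles; the z and m bounds
	// are optional and therefore pointer fields in the generated struct.
	zmin, zmax := -10.0, 42.5
	bbox := &parquet.BoundingBox{
		Xmin: -122.5, Xmax: -122.3,
		Ymin: 37.7, Ymax: 37.8,
		Zmin: &zmin, Zmax: &zmax,
	}

	// GeospatialStatistics wraps the bounding box plus the observed geospatial
	// type codes; leaving GeospatialTypes nil means the codes are not known.
	geo := parquet.NewGeospatialStatistics()
	geo.Bbox = bbox
	geo.GeospatialTypes = []int32{1, 3} // e.g. WKB codes for Point and Polygon

	// TimeUnit is a Thrift union: TimeUnit.Write rejects anything other than
	// exactly one of MILLIS / MICROS / NANOS being set.
	unit := parquet.NewTimeUnit()
	unit.MICROS = parquet.NewMicroSeconds()

	ts := parquet.NewTimestampType()
	ts.IsAdjustedToUTC = true // required field, enforced by TimestampType.Read
	ts.Unit = unit            // required field

	fmt.Println(geo.String())
	fmt.Println(ts.String())
}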
+func (p *TimestampType) GetUnit() *TimeUnit { + if !p.IsSetUnit() { + return TimestampType_Unit_DEFAULT + } + return p.Unit +} + +func (p *TimestampType) IsSetUnit() bool { + return p.Unit != nil +} + +func (p *TimestampType) Read(ctx context.Context, iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(ctx); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } + var issetIsAdjustedToUTC bool = false; + var issetUnit bool = false; for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin(ctx) @@ -2451,8 +3598,33 @@ func (p *NanoSeconds) Read(ctx context.Context, iprot thrift.TProtocol) error { if fieldTypeId == thrift.STOP { break } - if err := iprot.Skip(ctx, fieldTypeId); err != nil { - return err + switch fieldId { + case 1: + if fieldTypeId == thrift.BOOL { + if err := p.ReadField1(ctx, iprot); err != nil { + return err + } + issetIsAdjustedToUTC = true + } else { + if err := iprot.Skip(ctx, fieldTypeId); err != nil { + return err + } + } + case 2: + if fieldTypeId == thrift.STRUCT { + if err := p.ReadField2(ctx, iprot); err != nil { + return err + } + issetUnit = true + } else { + if err := iprot.Skip(ctx, fieldTypeId); err != nil { + return err + } + } + default: + if err := iprot.Skip(ctx, fieldTypeId); err != nil { + return err + } } if err := iprot.ReadFieldEnd(ctx); err != nil { return err @@ -2461,14 +3633,39 @@ func (p *NanoSeconds) Read(ctx context.Context, iprot thrift.TProtocol) error { if err := iprot.ReadStructEnd(ctx); err != nil { return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } + if !issetIsAdjustedToUTC{ + return thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("Required field IsAdjustedToUTC is not set")); + } + if !issetUnit{ + return thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("Required field Unit is not set")); + } return nil } -func (p *NanoSeconds) Write(ctx context.Context, oprot thrift.TProtocol) error { - if err := oprot.WriteStructBegin(ctx, "NanoSeconds"); err != nil { +func (p *TimestampType) ReadField1(ctx context.Context, iprot thrift.TProtocol) error { + if v, err := iprot.ReadBool(ctx); err != nil { + return thrift.PrependError("error reading field 1: ", err) + } else { + p.IsAdjustedToUTC = v + } + return nil +} + +func (p *TimestampType) ReadField2(ctx context.Context, iprot thrift.TProtocol) error { + p.Unit = &TimeUnit{} + if err := p.Unit.Read(ctx, iprot); err != nil { + return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", p.Unit), err) + } + return nil +} + +func (p *TimestampType) Write(ctx context.Context, oprot thrift.TProtocol) error { + if err := oprot.WriteStructBegin(ctx, "TimestampType"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if p != nil { + if err := p.writeField1(ctx, oprot); err != nil { return err } + if err := p.writeField2(ctx, oprot); err != nil { return err } } if err := oprot.WriteFieldStop(ctx); err != nil { return thrift.PrependError("write field stop error: ", err) @@ -2479,113 +3676,110 @@ func (p *NanoSeconds) Write(ctx context.Context, oprot thrift.TProtocol) error { return nil } -func (p *NanoSeconds) Equals(other *NanoSeconds) bool { +func (p *TimestampType) writeField1(ctx context.Context, oprot thrift.TProtocol) (err error) { + if err := oprot.WriteFieldBegin(ctx, "isAdjustedToUTC", thrift.BOOL, 1); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:isAdjustedToUTC: ", p), err) + } + if err := 
oprot.WriteBool(ctx, bool(p.IsAdjustedToUTC)); err != nil { + return thrift.PrependError(fmt.Sprintf("%T.isAdjustedToUTC (1) field write error: ", p), err) + } + if err := oprot.WriteFieldEnd(ctx); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field end error 1:isAdjustedToUTC: ", p), err) + } + return err +} + +func (p *TimestampType) writeField2(ctx context.Context, oprot thrift.TProtocol) (err error) { + if err := oprot.WriteFieldBegin(ctx, "unit", thrift.STRUCT, 2); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field begin error 2:unit: ", p), err) + } + if err := p.Unit.Write(ctx, oprot); err != nil { + return thrift.PrependError(fmt.Sprintf("%T error writing struct: ", p.Unit), err) + } + if err := oprot.WriteFieldEnd(ctx); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field end error 2:unit: ", p), err) + } + return err +} + +func (p *TimestampType) Equals(other *TimestampType) bool { if p == other { return true } else if p == nil || other == nil { return false } + if p.IsAdjustedToUTC != other.IsAdjustedToUTC { return false } + if !p.Unit.Equals(other.Unit) { return false } return true } -func (p *NanoSeconds) String() string { +func (p *TimestampType) String() string { if p == nil { return "" } - return fmt.Sprintf("NanoSeconds(%+v)", *p) + return fmt.Sprintf("TimestampType(%+v)", *p) } -func (p *NanoSeconds) LogValue() slog.Value { +func (p *TimestampType) LogValue() slog.Value { if p == nil { return slog.AnyValue(nil) } v := thrift.SlogTStructWrapper{ - Type: "*parquet.NanoSeconds", + Type: "*parquet.TimestampType", Value: p, } return slog.AnyValue(v) } -var _ slog.LogValuer = (*NanoSeconds)(nil) +var _ slog.LogValuer = (*TimestampType)(nil) -func (p *NanoSeconds) Validate() error { +func (p *TimestampType) Validate() error { return nil } +// Time logical type annotation +// +// Allowed for physical types: INT32 (millis), INT64 (micros, nanos) +// // Attributes: -// - MILLIS -// - MICROS -// - NANOS +// - IsAdjustedToUTC +// - Unit // -type TimeUnit struct { - MILLIS *MilliSeconds `thrift:"MILLIS,1" db:"MILLIS" json:"MILLIS,omitempty"` - MICROS *MicroSeconds `thrift:"MICROS,2" db:"MICROS" json:"MICROS,omitempty"` - NANOS *NanoSeconds `thrift:"NANOS,3" db:"NANOS" json:"NANOS,omitempty"` +type TimeType struct { + IsAdjustedToUTC bool `thrift:"isAdjustedToUTC,1,required" db:"isAdjustedToUTC" json:"isAdjustedToUTC"` + Unit *TimeUnit `thrift:"unit,2,required" db:"unit" json:"unit"` } -func NewTimeUnit() *TimeUnit { - return &TimeUnit{} +func NewTimeType() *TimeType { + return &TimeType{} } -var TimeUnit_MILLIS_DEFAULT *MilliSeconds - -func (p *TimeUnit) GetMILLIS() *MilliSeconds { - if !p.IsSetMILLIS() { - return TimeUnit_MILLIS_DEFAULT - } - return p.MILLIS -} -var TimeUnit_MICROS_DEFAULT *MicroSeconds -func (p *TimeUnit) GetMICROS() *MicroSeconds { - if !p.IsSetMICROS() { - return TimeUnit_MICROS_DEFAULT - } - return p.MICROS +func (p *TimeType) GetIsAdjustedToUTC() bool { + return p.IsAdjustedToUTC } -var TimeUnit_NANOS_DEFAULT *NanoSeconds - -func (p *TimeUnit) GetNANOS() *NanoSeconds { - if !p.IsSetNANOS() { - return TimeUnit_NANOS_DEFAULT - } - return p.NANOS -} +var TimeType_Unit_DEFAULT *TimeUnit -func (p *TimeUnit) CountSetFieldsTimeUnit() int { - count := 0 - if (p.IsSetMILLIS()) { - count++ - } - if (p.IsSetMICROS()) { - count++ - } - if (p.IsSetNANOS()) { - count++ +func (p *TimeType) GetUnit() *TimeUnit { + if !p.IsSetUnit() { + return TimeType_Unit_DEFAULT } - return count - -} - -func (p *TimeUnit) IsSetMILLIS() 
bool { - return p.MILLIS != nil -} - -func (p *TimeUnit) IsSetMICROS() bool { - return p.MICROS != nil + return p.Unit } -func (p *TimeUnit) IsSetNANOS() bool { - return p.NANOS != nil +func (p *TimeType) IsSetUnit() bool { + return p.Unit != nil } -func (p *TimeUnit) Read(ctx context.Context, iprot thrift.TProtocol) error { +func (p *TimeType) Read(ctx context.Context, iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(ctx); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } + var issetIsAdjustedToUTC bool = false; + var issetUnit bool = false; for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin(ctx) @@ -2597,10 +3791,11 @@ func (p *TimeUnit) Read(ctx context.Context, iprot thrift.TProtocol) error { } switch fieldId { case 1: - if fieldTypeId == thrift.STRUCT { + if fieldTypeId == thrift.BOOL { if err := p.ReadField1(ctx, iprot); err != nil { return err } + issetIsAdjustedToUTC = true } else { if err := iprot.Skip(ctx, fieldTypeId); err != nil { return err @@ -2611,16 +3806,7 @@ func (p *TimeUnit) Read(ctx context.Context, iprot thrift.TProtocol) error { if err := p.ReadField2(ctx, iprot); err != nil { return err } - } else { - if err := iprot.Skip(ctx, fieldTypeId); err != nil { - return err - } - } - case 3: - if fieldTypeId == thrift.STRUCT { - if err := p.ReadField3(ctx, iprot); err != nil { - return err - } + issetUnit = true } else { if err := iprot.Skip(ctx, fieldTypeId); err != nil { return err @@ -2638,44 +3824,39 @@ func (p *TimeUnit) Read(ctx context.Context, iprot thrift.TProtocol) error { if err := iprot.ReadStructEnd(ctx); err != nil { return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } - return nil -} - -func (p *TimeUnit) ReadField1(ctx context.Context, iprot thrift.TProtocol) error { - p.MILLIS = &MilliSeconds{} - if err := p.MILLIS.Read(ctx, iprot); err != nil { - return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", p.MILLIS), err) + if !issetIsAdjustedToUTC{ + return thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("Required field IsAdjustedToUTC is not set")); + } + if !issetUnit{ + return thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("Required field Unit is not set")); } return nil } -func (p *TimeUnit) ReadField2(ctx context.Context, iprot thrift.TProtocol) error { - p.MICROS = &MicroSeconds{} - if err := p.MICROS.Read(ctx, iprot); err != nil { - return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", p.MICROS), err) +func (p *TimeType) ReadField1(ctx context.Context, iprot thrift.TProtocol) error { + if v, err := iprot.ReadBool(ctx); err != nil { + return thrift.PrependError("error reading field 1: ", err) + } else { + p.IsAdjustedToUTC = v } return nil } -func (p *TimeUnit) ReadField3(ctx context.Context, iprot thrift.TProtocol) error { - p.NANOS = &NanoSeconds{} - if err := p.NANOS.Read(ctx, iprot); err != nil { - return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", p.NANOS), err) +func (p *TimeType) ReadField2(ctx context.Context, iprot thrift.TProtocol) error { + p.Unit = &TimeUnit{} + if err := p.Unit.Read(ctx, iprot); err != nil { + return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", p.Unit), err) } return nil } -func (p *TimeUnit) Write(ctx context.Context, oprot thrift.TProtocol) error { - if c := p.CountSetFieldsTimeUnit(); c != 1 { - return fmt.Errorf("%T write union: exactly one field must be set (%d set)", p, c) - } - if err := oprot.WriteStructBegin(ctx, "TimeUnit"); err != 
nil { +func (p *TimeType) Write(ctx context.Context, oprot thrift.TProtocol) error { + if err := oprot.WriteStructBegin(ctx, "TimeType"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if p != nil { if err := p.writeField1(ctx, oprot); err != nil { return err } if err := p.writeField2(ctx, oprot); err != nil { return err } - if err := p.writeField3(ctx, oprot); err != nil { return err } } if err := oprot.WriteFieldStop(ctx); err != nil { return thrift.PrependError("write field stop error: ", err) @@ -2686,130 +3867,105 @@ func (p *TimeUnit) Write(ctx context.Context, oprot thrift.TProtocol) error { return nil } -func (p *TimeUnit) writeField1(ctx context.Context, oprot thrift.TProtocol) (err error) { - if p.IsSetMILLIS() { - if err := oprot.WriteFieldBegin(ctx, "MILLIS", thrift.STRUCT, 1); err != nil { - return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:MILLIS: ", p), err) - } - if err := p.MILLIS.Write(ctx, oprot); err != nil { - return thrift.PrependError(fmt.Sprintf("%T error writing struct: ", p.MILLIS), err) - } - if err := oprot.WriteFieldEnd(ctx); err != nil { - return thrift.PrependError(fmt.Sprintf("%T write field end error 1:MILLIS: ", p), err) - } - } - return err -} - -func (p *TimeUnit) writeField2(ctx context.Context, oprot thrift.TProtocol) (err error) { - if p.IsSetMICROS() { - if err := oprot.WriteFieldBegin(ctx, "MICROS", thrift.STRUCT, 2); err != nil { - return thrift.PrependError(fmt.Sprintf("%T write field begin error 2:MICROS: ", p), err) - } - if err := p.MICROS.Write(ctx, oprot); err != nil { - return thrift.PrependError(fmt.Sprintf("%T error writing struct: ", p.MICROS), err) - } - if err := oprot.WriteFieldEnd(ctx); err != nil { - return thrift.PrependError(fmt.Sprintf("%T write field end error 2:MICROS: ", p), err) - } +func (p *TimeType) writeField1(ctx context.Context, oprot thrift.TProtocol) (err error) { + if err := oprot.WriteFieldBegin(ctx, "isAdjustedToUTC", thrift.BOOL, 1); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:isAdjustedToUTC: ", p), err) + } + if err := oprot.WriteBool(ctx, bool(p.IsAdjustedToUTC)); err != nil { + return thrift.PrependError(fmt.Sprintf("%T.isAdjustedToUTC (1) field write error: ", p), err) + } + if err := oprot.WriteFieldEnd(ctx); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field end error 1:isAdjustedToUTC: ", p), err) } return err } -func (p *TimeUnit) writeField3(ctx context.Context, oprot thrift.TProtocol) (err error) { - if p.IsSetNANOS() { - if err := oprot.WriteFieldBegin(ctx, "NANOS", thrift.STRUCT, 3); err != nil { - return thrift.PrependError(fmt.Sprintf("%T write field begin error 3:NANOS: ", p), err) - } - if err := p.NANOS.Write(ctx, oprot); err != nil { - return thrift.PrependError(fmt.Sprintf("%T error writing struct: ", p.NANOS), err) - } - if err := oprot.WriteFieldEnd(ctx); err != nil { - return thrift.PrependError(fmt.Sprintf("%T write field end error 3:NANOS: ", p), err) - } +func (p *TimeType) writeField2(ctx context.Context, oprot thrift.TProtocol) (err error) { + if err := oprot.WriteFieldBegin(ctx, "unit", thrift.STRUCT, 2); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field begin error 2:unit: ", p), err) + } + if err := p.Unit.Write(ctx, oprot); err != nil { + return thrift.PrependError(fmt.Sprintf("%T error writing struct: ", p.Unit), err) + } + if err := oprot.WriteFieldEnd(ctx); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field end error 
2:unit: ", p), err) } return err } -func (p *TimeUnit) Equals(other *TimeUnit) bool { +func (p *TimeType) Equals(other *TimeType) bool { if p == other { return true } else if p == nil || other == nil { return false } - if !p.MILLIS.Equals(other.MILLIS) { return false } - if !p.MICROS.Equals(other.MICROS) { return false } - if !p.NANOS.Equals(other.NANOS) { return false } + if p.IsAdjustedToUTC != other.IsAdjustedToUTC { return false } + if !p.Unit.Equals(other.Unit) { return false } return true } -func (p *TimeUnit) String() string { +func (p *TimeType) String() string { if p == nil { return "" } - return fmt.Sprintf("TimeUnit(%+v)", *p) + return fmt.Sprintf("TimeType(%+v)", *p) } -func (p *TimeUnit) LogValue() slog.Value { +func (p *TimeType) LogValue() slog.Value { if p == nil { return slog.AnyValue(nil) } v := thrift.SlogTStructWrapper{ - Type: "*parquet.TimeUnit", + Type: "*parquet.TimeType", Value: p, } return slog.AnyValue(v) } -var _ slog.LogValuer = (*TimeUnit)(nil) +var _ slog.LogValuer = (*TimeType)(nil) -func (p *TimeUnit) Validate() error { +func (p *TimeType) Validate() error { return nil } -// Timestamp logical type annotation +// Integer logical type annotation // -// Allowed for physical types: INT64 +// bitWidth must be 8, 16, 32, or 64. +// +// Allowed for physical types: INT32, INT64 // // Attributes: -// - IsAdjustedToUTC -// - Unit +// - BitWidth +// - IsSigned // -type TimestampType struct { - IsAdjustedToUTC bool `thrift:"isAdjustedToUTC,1,required" db:"isAdjustedToUTC" json:"isAdjustedToUTC"` - Unit *TimeUnit `thrift:"unit,2,required" db:"unit" json:"unit"` +type IntType struct { + BitWidth int8 `thrift:"bitWidth,1,required" db:"bitWidth" json:"bitWidth"` + IsSigned bool `thrift:"isSigned,2,required" db:"isSigned" json:"isSigned"` } -func NewTimestampType() *TimestampType { - return &TimestampType{} +func NewIntType() *IntType { + return &IntType{} } -func (p *TimestampType) GetIsAdjustedToUTC() bool { - return p.IsAdjustedToUTC +func (p *IntType) GetBitWidth() int8 { + return p.BitWidth } -var TimestampType_Unit_DEFAULT *TimeUnit -func (p *TimestampType) GetUnit() *TimeUnit { - if !p.IsSetUnit() { - return TimestampType_Unit_DEFAULT - } - return p.Unit -} -func (p *TimestampType) IsSetUnit() bool { - return p.Unit != nil +func (p *IntType) GetIsSigned() bool { + return p.IsSigned } -func (p *TimestampType) Read(ctx context.Context, iprot thrift.TProtocol) error { +func (p *IntType) Read(ctx context.Context, iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(ctx); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } - var issetIsAdjustedToUTC bool = false; - var issetUnit bool = false; + var issetBitWidth bool = false; + var issetIsSigned bool = false; for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin(ctx) @@ -2821,22 +3977,22 @@ func (p *TimestampType) Read(ctx context.Context, iprot thrift.TProtocol) error } switch fieldId { case 1: - if fieldTypeId == thrift.BOOL { + if fieldTypeId == thrift.BYTE { if err := p.ReadField1(ctx, iprot); err != nil { return err } - issetIsAdjustedToUTC = true + issetBitWidth = true } else { if err := iprot.Skip(ctx, fieldTypeId); err != nil { return err } } case 2: - if fieldTypeId == thrift.STRUCT { + if fieldTypeId == thrift.BOOL { if err := p.ReadField2(ctx, iprot); err != nil { return err } - issetUnit = true + issetIsSigned = true } else { if err := iprot.Skip(ctx, fieldTypeId); err != nil { return err @@ -2854,34 +4010,36 @@ func (p *TimestampType) Read(ctx context.Context, 
iprot thrift.TProtocol) error if err := iprot.ReadStructEnd(ctx); err != nil { return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } - if !issetIsAdjustedToUTC{ - return thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("Required field IsAdjustedToUTC is not set")); + if !issetBitWidth{ + return thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("Required field BitWidth is not set")); } - if !issetUnit{ - return thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("Required field Unit is not set")); + if !issetIsSigned{ + return thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("Required field IsSigned is not set")); } return nil } -func (p *TimestampType) ReadField1(ctx context.Context, iprot thrift.TProtocol) error { - if v, err := iprot.ReadBool(ctx); err != nil { +func (p *IntType) ReadField1(ctx context.Context, iprot thrift.TProtocol) error { + if v, err := iprot.ReadByte(ctx); err != nil { return thrift.PrependError("error reading field 1: ", err) } else { - p.IsAdjustedToUTC = v + temp := int8(v) + p.BitWidth = temp } return nil } -func (p *TimestampType) ReadField2(ctx context.Context, iprot thrift.TProtocol) error { - p.Unit = &TimeUnit{} - if err := p.Unit.Read(ctx, iprot); err != nil { - return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", p.Unit), err) +func (p *IntType) ReadField2(ctx context.Context, iprot thrift.TProtocol) error { + if v, err := iprot.ReadBool(ctx); err != nil { + return thrift.PrependError("error reading field 2: ", err) + } else { + p.IsSigned = v } return nil } -func (p *TimestampType) Write(ctx context.Context, oprot thrift.TProtocol) error { - if err := oprot.WriteStructBegin(ctx, "TimestampType"); err != nil { +func (p *IntType) Write(ctx context.Context, oprot thrift.TProtocol) error { + if err := oprot.WriteStructBegin(ctx, "IntType"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if p != nil { @@ -2897,110 +4055,167 @@ func (p *TimestampType) Write(ctx context.Context, oprot thrift.TProtocol) error return nil } -func (p *TimestampType) writeField1(ctx context.Context, oprot thrift.TProtocol) (err error) { - if err := oprot.WriteFieldBegin(ctx, "isAdjustedToUTC", thrift.BOOL, 1); err != nil { - return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:isAdjustedToUTC: ", p), err) +func (p *IntType) writeField1(ctx context.Context, oprot thrift.TProtocol) (err error) { + if err := oprot.WriteFieldBegin(ctx, "bitWidth", thrift.BYTE, 1); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:bitWidth: ", p), err) } - if err := oprot.WriteBool(ctx, bool(p.IsAdjustedToUTC)); err != nil { - return thrift.PrependError(fmt.Sprintf("%T.isAdjustedToUTC (1) field write error: ", p), err) + if err := oprot.WriteByte(ctx, int8(p.BitWidth)); err != nil { + return thrift.PrependError(fmt.Sprintf("%T.bitWidth (1) field write error: ", p), err) } if err := oprot.WriteFieldEnd(ctx); err != nil { - return thrift.PrependError(fmt.Sprintf("%T write field end error 1:isAdjustedToUTC: ", p), err) + return thrift.PrependError(fmt.Sprintf("%T write field end error 1:bitWidth: ", p), err) } return err } -func (p *TimestampType) writeField2(ctx context.Context, oprot thrift.TProtocol) (err error) { - if err := oprot.WriteFieldBegin(ctx, "unit", thrift.STRUCT, 2); err != nil { - return thrift.PrependError(fmt.Sprintf("%T write field begin error 2:unit: ", p), err) +func (p 
*IntType) writeField2(ctx context.Context, oprot thrift.TProtocol) (err error) { + if err := oprot.WriteFieldBegin(ctx, "isSigned", thrift.BOOL, 2); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field begin error 2:isSigned: ", p), err) } - if err := p.Unit.Write(ctx, oprot); err != nil { - return thrift.PrependError(fmt.Sprintf("%T error writing struct: ", p.Unit), err) + if err := oprot.WriteBool(ctx, bool(p.IsSigned)); err != nil { + return thrift.PrependError(fmt.Sprintf("%T.isSigned (2) field write error: ", p), err) } if err := oprot.WriteFieldEnd(ctx); err != nil { - return thrift.PrependError(fmt.Sprintf("%T write field end error 2:unit: ", p), err) + return thrift.PrependError(fmt.Sprintf("%T write field end error 2:isSigned: ", p), err) } return err } -func (p *TimestampType) Equals(other *TimestampType) bool { +func (p *IntType) Equals(other *IntType) bool { if p == other { return true } else if p == nil || other == nil { return false } - if p.IsAdjustedToUTC != other.IsAdjustedToUTC { return false } - if !p.Unit.Equals(other.Unit) { return false } + if p.BitWidth != other.BitWidth { return false } + if p.IsSigned != other.IsSigned { return false } return true } -func (p *TimestampType) String() string { +func (p *IntType) String() string { if p == nil { return "" } - return fmt.Sprintf("TimestampType(%+v)", *p) + return fmt.Sprintf("IntType(%+v)", *p) } -func (p *TimestampType) LogValue() slog.Value { +func (p *IntType) LogValue() slog.Value { if p == nil { return slog.AnyValue(nil) } v := thrift.SlogTStructWrapper{ - Type: "*parquet.TimestampType", + Type: "*parquet.IntType", Value: p, } return slog.AnyValue(v) } -var _ slog.LogValuer = (*TimestampType)(nil) +var _ slog.LogValuer = (*IntType)(nil) -func (p *TimestampType) Validate() error { +func (p *IntType) Validate() error { return nil } -// Time logical type annotation -// -// Allowed for physical types: INT32 (millis), INT64 (micros, nanos) -// -// Attributes: -// - IsAdjustedToUTC -// - Unit +// Embedded JSON logical type annotation // -type TimeType struct { - IsAdjustedToUTC bool `thrift:"isAdjustedToUTC,1,required" db:"isAdjustedToUTC" json:"isAdjustedToUTC"` - Unit *TimeUnit `thrift:"unit,2,required" db:"unit" json:"unit"` +// Allowed for physical types: BYTE_ARRAY +type JsonType struct { } -func NewTimeType() *TimeType { - return &TimeType{} +func NewJsonType() *JsonType { + return &JsonType{} } +func (p *JsonType) Read(ctx context.Context, iprot thrift.TProtocol) error { + if _, err := iprot.ReadStructBegin(ctx); err != nil { + return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) + } -func (p *TimeType) GetIsAdjustedToUTC() bool { - return p.IsAdjustedToUTC + for { + _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin(ctx) + if err != nil { + return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) + } + if fieldTypeId == thrift.STOP { + break + } + if err := iprot.Skip(ctx, fieldTypeId); err != nil { + return err + } + if err := iprot.ReadFieldEnd(ctx); err != nil { + return err + } + } + if err := iprot.ReadStructEnd(ctx); err != nil { + return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) + } + return nil } -var TimeType_Unit_DEFAULT *TimeUnit +func (p *JsonType) Write(ctx context.Context, oprot thrift.TProtocol) error { + if err := oprot.WriteStructBegin(ctx, "JsonType"); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) + } + if p != nil { + } + if err := 
oprot.WriteFieldStop(ctx); err != nil { + return thrift.PrependError("write field stop error: ", err) + } + if err := oprot.WriteStructEnd(ctx); err != nil { + return thrift.PrependError("write struct stop error: ", err) + } + return nil +} + +func (p *JsonType) Equals(other *JsonType) bool { + if p == other { + return true + } else if p == nil || other == nil { + return false + } + return true +} + +func (p *JsonType) String() string { + if p == nil { + return "" + } + return fmt.Sprintf("JsonType(%+v)", *p) +} + +func (p *JsonType) LogValue() slog.Value { + if p == nil { + return slog.AnyValue(nil) + } + v := thrift.SlogTStructWrapper{ + Type: "*parquet.JsonType", + Value: p, + } + return slog.AnyValue(v) +} + +var _ slog.LogValuer = (*JsonType)(nil) + +func (p *JsonType) Validate() error { + return nil +} -func (p *TimeType) GetUnit() *TimeUnit { - if !p.IsSetUnit() { - return TimeType_Unit_DEFAULT - } - return p.Unit +// Embedded BSON logical type annotation +// +// Allowed for physical types: BYTE_ARRAY +type BsonType struct { } -func (p *TimeType) IsSetUnit() bool { - return p.Unit != nil +func NewBsonType() *BsonType { + return &BsonType{} } -func (p *TimeType) Read(ctx context.Context, iprot thrift.TProtocol) error { +func (p *BsonType) Read(ctx context.Context, iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(ctx); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } - var issetIsAdjustedToUTC bool = false; - var issetUnit bool = false; for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin(ctx) @@ -3010,33 +4225,8 @@ func (p *TimeType) Read(ctx context.Context, iprot thrift.TProtocol) error { if fieldTypeId == thrift.STOP { break } - switch fieldId { - case 1: - if fieldTypeId == thrift.BOOL { - if err := p.ReadField1(ctx, iprot); err != nil { - return err - } - issetIsAdjustedToUTC = true - } else { - if err := iprot.Skip(ctx, fieldTypeId); err != nil { - return err - } - } - case 2: - if fieldTypeId == thrift.STRUCT { - if err := p.ReadField2(ctx, iprot); err != nil { - return err - } - issetUnit = true - } else { - if err := iprot.Skip(ctx, fieldTypeId); err != nil { - return err - } - } - default: - if err := iprot.Skip(ctx, fieldTypeId); err != nil { - return err - } + if err := iprot.Skip(ctx, fieldTypeId); err != nil { + return err } if err := iprot.ReadFieldEnd(ctx); err != nil { return err @@ -3045,39 +4235,14 @@ func (p *TimeType) Read(ctx context.Context, iprot thrift.TProtocol) error { if err := iprot.ReadStructEnd(ctx); err != nil { return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } - if !issetIsAdjustedToUTC{ - return thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("Required field IsAdjustedToUTC is not set")); - } - if !issetUnit{ - return thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("Required field Unit is not set")); - } - return nil -} - -func (p *TimeType) ReadField1(ctx context.Context, iprot thrift.TProtocol) error { - if v, err := iprot.ReadBool(ctx); err != nil { - return thrift.PrependError("error reading field 1: ", err) - } else { - p.IsAdjustedToUTC = v - } - return nil -} - -func (p *TimeType) ReadField2(ctx context.Context, iprot thrift.TProtocol) error { - p.Unit = &TimeUnit{} - if err := p.Unit.Read(ctx, iprot); err != nil { - return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", p.Unit), err) - } return nil } -func (p *TimeType) Write(ctx context.Context, oprot thrift.TProtocol) error { - if err := 
oprot.WriteStructBegin(ctx, "TimeType"); err != nil { +func (p *BsonType) Write(ctx context.Context, oprot thrift.TProtocol) error { + if err := oprot.WriteStructBegin(ctx, "BsonType"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if p != nil { - if err := p.writeField1(ctx, oprot); err != nil { return err } - if err := p.writeField2(ctx, oprot); err != nil { return err } } if err := oprot.WriteFieldStop(ctx); err != nil { return thrift.PrependError("write field stop error: ", err) @@ -3088,105 +4253,70 @@ func (p *TimeType) Write(ctx context.Context, oprot thrift.TProtocol) error { return nil } -func (p *TimeType) writeField1(ctx context.Context, oprot thrift.TProtocol) (err error) { - if err := oprot.WriteFieldBegin(ctx, "isAdjustedToUTC", thrift.BOOL, 1); err != nil { - return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:isAdjustedToUTC: ", p), err) - } - if err := oprot.WriteBool(ctx, bool(p.IsAdjustedToUTC)); err != nil { - return thrift.PrependError(fmt.Sprintf("%T.isAdjustedToUTC (1) field write error: ", p), err) - } - if err := oprot.WriteFieldEnd(ctx); err != nil { - return thrift.PrependError(fmt.Sprintf("%T write field end error 1:isAdjustedToUTC: ", p), err) - } - return err -} - -func (p *TimeType) writeField2(ctx context.Context, oprot thrift.TProtocol) (err error) { - if err := oprot.WriteFieldBegin(ctx, "unit", thrift.STRUCT, 2); err != nil { - return thrift.PrependError(fmt.Sprintf("%T write field begin error 2:unit: ", p), err) - } - if err := p.Unit.Write(ctx, oprot); err != nil { - return thrift.PrependError(fmt.Sprintf("%T error writing struct: ", p.Unit), err) - } - if err := oprot.WriteFieldEnd(ctx); err != nil { - return thrift.PrependError(fmt.Sprintf("%T write field end error 2:unit: ", p), err) - } - return err -} - -func (p *TimeType) Equals(other *TimeType) bool { +func (p *BsonType) Equals(other *BsonType) bool { if p == other { return true } else if p == nil || other == nil { return false } - if p.IsAdjustedToUTC != other.IsAdjustedToUTC { return false } - if !p.Unit.Equals(other.Unit) { return false } return true } -func (p *TimeType) String() string { +func (p *BsonType) String() string { if p == nil { return "" } - return fmt.Sprintf("TimeType(%+v)", *p) + return fmt.Sprintf("BsonType(%+v)", *p) } -func (p *TimeType) LogValue() slog.Value { +func (p *BsonType) LogValue() slog.Value { if p == nil { return slog.AnyValue(nil) } v := thrift.SlogTStructWrapper{ - Type: "*parquet.TimeType", + Type: "*parquet.BsonType", Value: p, } return slog.AnyValue(v) } -var _ slog.LogValuer = (*TimeType)(nil) +var _ slog.LogValuer = (*BsonType)(nil) -func (p *TimeType) Validate() error { +func (p *BsonType) Validate() error { return nil } -// Integer logical type annotation -// -// bitWidth must be 8, 16, 32, or 64. 
-// -// Allowed for physical types: INT32, INT64 +// Embedded Variant logical type annotation // // Attributes: -// - BitWidth -// - IsSigned +// - SpecificationVersion // -type IntType struct { - BitWidth int8 `thrift:"bitWidth,1,required" db:"bitWidth" json:"bitWidth"` - IsSigned bool `thrift:"isSigned,2,required" db:"isSigned" json:"isSigned"` +type VariantType struct { + SpecificationVersion *int8 `thrift:"specification_version,1" db:"specification_version" json:"specification_version,omitempty"` } -func NewIntType() *IntType { - return &IntType{} +func NewVariantType() *VariantType { + return &VariantType{} } +var VariantType_SpecificationVersion_DEFAULT int8 - -func (p *IntType) GetBitWidth() int8 { - return p.BitWidth +func (p *VariantType) GetSpecificationVersion() int8 { + if !p.IsSetSpecificationVersion() { + return VariantType_SpecificationVersion_DEFAULT + } + return *p.SpecificationVersion } - - -func (p *IntType) GetIsSigned() bool { - return p.IsSigned +func (p *VariantType) IsSetSpecificationVersion() bool { + return p.SpecificationVersion != nil } -func (p *IntType) Read(ctx context.Context, iprot thrift.TProtocol) error { +func (p *VariantType) Read(ctx context.Context, iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(ctx); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } - var issetBitWidth bool = false; - var issetIsSigned bool = false; for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin(ctx) @@ -3202,18 +4332,6 @@ func (p *IntType) Read(ctx context.Context, iprot thrift.TProtocol) error { if err := p.ReadField1(ctx, iprot); err != nil { return err } - issetBitWidth = true - } else { - if err := iprot.Skip(ctx, fieldTypeId); err != nil { - return err - } - } - case 2: - if fieldTypeId == thrift.BOOL { - if err := p.ReadField2(ctx, iprot); err != nil { - return err - } - issetIsSigned = true } else { if err := iprot.Skip(ctx, fieldTypeId); err != nil { return err @@ -3231,41 +4349,25 @@ func (p *IntType) Read(ctx context.Context, iprot thrift.TProtocol) error { if err := iprot.ReadStructEnd(ctx); err != nil { return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } - if !issetBitWidth{ - return thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("Required field BitWidth is not set")); - } - if !issetIsSigned{ - return thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("Required field IsSigned is not set")); - } return nil } -func (p *IntType) ReadField1(ctx context.Context, iprot thrift.TProtocol) error { +func (p *VariantType) ReadField1(ctx context.Context, iprot thrift.TProtocol) error { if v, err := iprot.ReadByte(ctx); err != nil { return thrift.PrependError("error reading field 1: ", err) } else { temp := int8(v) - p.BitWidth = temp - } - return nil -} - -func (p *IntType) ReadField2(ctx context.Context, iprot thrift.TProtocol) error { - if v, err := iprot.ReadBool(ctx); err != nil { - return thrift.PrependError("error reading field 2: ", err) - } else { - p.IsSigned = v + p.SpecificationVersion = &temp } return nil } -func (p *IntType) Write(ctx context.Context, oprot thrift.TProtocol) error { - if err := oprot.WriteStructBegin(ctx, "IntType"); err != nil { +func (p *VariantType) Write(ctx context.Context, oprot thrift.TProtocol) error { + if err := oprot.WriteStructBegin(ctx, "VariantType"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if p != nil { if err := p.writeField1(ctx, oprot); err != nil { 
return err } - if err := p.writeField2(ctx, oprot); err != nil { return err } } if err := oprot.WriteFieldStop(ctx); err != nil { return thrift.PrependError("write field stop error: ", err) @@ -3276,78 +4378,98 @@ func (p *IntType) Write(ctx context.Context, oprot thrift.TProtocol) error { return nil } -func (p *IntType) writeField1(ctx context.Context, oprot thrift.TProtocol) (err error) { - if err := oprot.WriteFieldBegin(ctx, "bitWidth", thrift.BYTE, 1); err != nil { - return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:bitWidth: ", p), err) - } - if err := oprot.WriteByte(ctx, int8(p.BitWidth)); err != nil { - return thrift.PrependError(fmt.Sprintf("%T.bitWidth (1) field write error: ", p), err) - } - if err := oprot.WriteFieldEnd(ctx); err != nil { - return thrift.PrependError(fmt.Sprintf("%T write field end error 1:bitWidth: ", p), err) - } - return err -} - -func (p *IntType) writeField2(ctx context.Context, oprot thrift.TProtocol) (err error) { - if err := oprot.WriteFieldBegin(ctx, "isSigned", thrift.BOOL, 2); err != nil { - return thrift.PrependError(fmt.Sprintf("%T write field begin error 2:isSigned: ", p), err) - } - if err := oprot.WriteBool(ctx, bool(p.IsSigned)); err != nil { - return thrift.PrependError(fmt.Sprintf("%T.isSigned (2) field write error: ", p), err) - } - if err := oprot.WriteFieldEnd(ctx); err != nil { - return thrift.PrependError(fmt.Sprintf("%T write field end error 2:isSigned: ", p), err) +func (p *VariantType) writeField1(ctx context.Context, oprot thrift.TProtocol) (err error) { + if p.IsSetSpecificationVersion() { + if err := oprot.WriteFieldBegin(ctx, "specification_version", thrift.BYTE, 1); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:specification_version: ", p), err) + } + if err := oprot.WriteByte(ctx, int8(*p.SpecificationVersion)); err != nil { + return thrift.PrependError(fmt.Sprintf("%T.specification_version (1) field write error: ", p), err) + } + if err := oprot.WriteFieldEnd(ctx); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field end error 1:specification_version: ", p), err) + } } return err } -func (p *IntType) Equals(other *IntType) bool { +func (p *VariantType) Equals(other *VariantType) bool { if p == other { return true } else if p == nil || other == nil { return false } - if p.BitWidth != other.BitWidth { return false } - if p.IsSigned != other.IsSigned { return false } + if p.SpecificationVersion != other.SpecificationVersion { + if p.SpecificationVersion == nil || other.SpecificationVersion == nil { + return false + } + if (*p.SpecificationVersion) != (*other.SpecificationVersion) { return false } + } return true } -func (p *IntType) String() string { +func (p *VariantType) String() string { if p == nil { return "" } - return fmt.Sprintf("IntType(%+v)", *p) + return fmt.Sprintf("VariantType(%+v)", *p) } -func (p *IntType) LogValue() slog.Value { +func (p *VariantType) LogValue() slog.Value { if p == nil { return slog.AnyValue(nil) } v := thrift.SlogTStructWrapper{ - Type: "*parquet.IntType", + Type: "*parquet.VariantType", Value: p, } return slog.AnyValue(v) } -var _ slog.LogValuer = (*IntType)(nil) +var _ slog.LogValuer = (*VariantType)(nil) -func (p *IntType) Validate() error { +func (p *VariantType) Validate() error { return nil } -// Embedded JSON logical type annotation +// Embedded Geometry logical type annotation // -// Allowed for physical types: BYTE_ARRAY -type JsonType struct { +// Geospatial features in the Well-Known Binary (WKB) format and 
edges interpolation +// is always linear/planar. +// +// A custom CRS can be set by the crs field. If unset, it defaults to "OGC:CRS84", +// which means that the geometries must be stored in longitude, latitude based on +// the WGS84 datum. +// +// Allowed for physical type: BYTE_ARRAY. +// +// See Geospatial.md for details. +// +// Attributes: +// - Crs +// +type GeometryType struct { + Crs *string `thrift:"crs,1" db:"crs" json:"crs,omitempty"` } -func NewJsonType() *JsonType { - return &JsonType{} +func NewGeometryType() *GeometryType { + return &GeometryType{} } -func (p *JsonType) Read(ctx context.Context, iprot thrift.TProtocol) error { +var GeometryType_Crs_DEFAULT string + +func (p *GeometryType) GetCrs() string { + if !p.IsSetCrs() { + return GeometryType_Crs_DEFAULT + } + return *p.Crs +} + +func (p *GeometryType) IsSetCrs() bool { + return p.Crs != nil +} + +func (p *GeometryType) Read(ctx context.Context, iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(ctx); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } @@ -3361,8 +4483,21 @@ func (p *JsonType) Read(ctx context.Context, iprot thrift.TProtocol) error { if fieldTypeId == thrift.STOP { break } - if err := iprot.Skip(ctx, fieldTypeId); err != nil { - return err + switch fieldId { + case 1: + if fieldTypeId == thrift.STRING { + if err := p.ReadField1(ctx, iprot); err != nil { + return err + } + } else { + if err := iprot.Skip(ctx, fieldTypeId); err != nil { + return err + } + } + default: + if err := iprot.Skip(ctx, fieldTypeId); err != nil { + return err + } } if err := iprot.ReadFieldEnd(ctx); err != nil { return err @@ -3374,11 +4509,21 @@ func (p *JsonType) Read(ctx context.Context, iprot thrift.TProtocol) error { return nil } -func (p *JsonType) Write(ctx context.Context, oprot thrift.TProtocol) error { - if err := oprot.WriteStructBegin(ctx, "JsonType"); err != nil { +func (p *GeometryType) ReadField1(ctx context.Context, iprot thrift.TProtocol) error { + if v, err := iprot.ReadString(ctx); err != nil { + return thrift.PrependError("error reading field 1: ", err) + } else { + p.Crs = &v + } + return nil +} + +func (p *GeometryType) Write(ctx context.Context, oprot thrift.TProtocol) error { + if err := oprot.WriteStructBegin(ctx, "GeometryType"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if p != nil { + if err := p.writeField1(ctx, oprot); err != nil { return err } } if err := oprot.WriteFieldStop(ctx); err != nil { return thrift.PrependError("write field stop error: ", err) @@ -3389,50 +4534,116 @@ func (p *JsonType) Write(ctx context.Context, oprot thrift.TProtocol) error { return nil } -func (p *JsonType) Equals(other *JsonType) bool { +func (p *GeometryType) writeField1(ctx context.Context, oprot thrift.TProtocol) (err error) { + if p.IsSetCrs() { + if err := oprot.WriteFieldBegin(ctx, "crs", thrift.STRING, 1); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:crs: ", p), err) + } + if err := oprot.WriteString(ctx, string(*p.Crs)); err != nil { + return thrift.PrependError(fmt.Sprintf("%T.crs (1) field write error: ", p), err) + } + if err := oprot.WriteFieldEnd(ctx); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field end error 1:crs: ", p), err) + } + } + return err +} + +func (p *GeometryType) Equals(other *GeometryType) bool { if p == other { return true } else if p == nil || other == nil { return false } + if p.Crs != other.Crs { + if p.Crs == nil || other.Crs 
== nil { + return false + } + if (*p.Crs) != (*other.Crs) { return false } + } return true } -func (p *JsonType) String() string { +func (p *GeometryType) String() string { if p == nil { return "" } - return fmt.Sprintf("JsonType(%+v)", *p) + return fmt.Sprintf("GeometryType(%+v)", *p) } -func (p *JsonType) LogValue() slog.Value { +func (p *GeometryType) LogValue() slog.Value { if p == nil { return slog.AnyValue(nil) } v := thrift.SlogTStructWrapper{ - Type: "*parquet.JsonType", + Type: "*parquet.GeometryType", Value: p, } return slog.AnyValue(v) } -var _ slog.LogValuer = (*JsonType)(nil) +var _ slog.LogValuer = (*GeometryType)(nil) -func (p *JsonType) Validate() error { +func (p *GeometryType) Validate() error { return nil } -// Embedded BSON logical type annotation +// Embedded Geography logical type annotation // -// Allowed for physical types: BYTE_ARRAY -type BsonType struct { +// Geospatial features in the WKB format with an explicit (non-linear/non-planar) +// edges interpolation algorithm. +// +// A custom geographic CRS can be set by the crs field, where longitudes are +// bound by [-180, 180] and latitudes are bound by [-90, 90]. If unset, the CRS +// defaults to "OGC:CRS84". +// +// An optional algorithm can be set to correctly interpret edges interpolation +// of the geometries. If unset, the algorithm defaults to SPHERICAL. +// +// Allowed for physical type: BYTE_ARRAY. +// +// See Geospatial.md for details. +// +// Attributes: +// - Crs +// - Algorithm +// +type GeographyType struct { + Crs *string `thrift:"crs,1" db:"crs" json:"crs,omitempty"` + Algorithm *EdgeInterpolationAlgorithm `thrift:"algorithm,2" db:"algorithm" json:"algorithm,omitempty"` } -func NewBsonType() *BsonType { - return &BsonType{} +func NewGeographyType() *GeographyType { + return &GeographyType{} } -func (p *BsonType) Read(ctx context.Context, iprot thrift.TProtocol) error { +var GeographyType_Crs_DEFAULT string + +func (p *GeographyType) GetCrs() string { + if !p.IsSetCrs() { + return GeographyType_Crs_DEFAULT + } + return *p.Crs +} + +var GeographyType_Algorithm_DEFAULT EdgeInterpolationAlgorithm + +func (p *GeographyType) GetAlgorithm() EdgeInterpolationAlgorithm { + if !p.IsSetAlgorithm() { + return GeographyType_Algorithm_DEFAULT + } + return *p.Algorithm +} + +func (p *GeographyType) IsSetCrs() bool { + return p.Crs != nil +} + +func (p *GeographyType) IsSetAlgorithm() bool { + return p.Algorithm != nil +} + +func (p *GeographyType) Read(ctx context.Context, iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(ctx); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } @@ -3446,8 +4657,31 @@ func (p *BsonType) Read(ctx context.Context, iprot thrift.TProtocol) error { if fieldTypeId == thrift.STOP { break } - if err := iprot.Skip(ctx, fieldTypeId); err != nil { - return err + switch fieldId { + case 1: + if fieldTypeId == thrift.STRING { + if err := p.ReadField1(ctx, iprot); err != nil { + return err + } + } else { + if err := iprot.Skip(ctx, fieldTypeId); err != nil { + return err + } + } + case 2: + if fieldTypeId == thrift.I32 { + if err := p.ReadField2(ctx, iprot); err != nil { + return err + } + } else { + if err := iprot.Skip(ctx, fieldTypeId); err != nil { + return err + } + } + default: + if err := iprot.Skip(ctx, fieldTypeId); err != nil { + return err + } } if err := iprot.ReadFieldEnd(ctx); err != nil { return err @@ -3459,11 +4693,32 @@ func (p *BsonType) Read(ctx context.Context, iprot thrift.TProtocol) error { return nil } -func (p *BsonType) 
Write(ctx context.Context, oprot thrift.TProtocol) error { - if err := oprot.WriteStructBegin(ctx, "BsonType"); err != nil { +func (p *GeographyType) ReadField1(ctx context.Context, iprot thrift.TProtocol) error { + if v, err := iprot.ReadString(ctx); err != nil { + return thrift.PrependError("error reading field 1: ", err) + } else { + p.Crs = &v + } + return nil +} + +func (p *GeographyType) ReadField2(ctx context.Context, iprot thrift.TProtocol) error { + if v, err := iprot.ReadI32(ctx); err != nil { + return thrift.PrependError("error reading field 2: ", err) + } else { + temp := EdgeInterpolationAlgorithm(v) + p.Algorithm = &temp + } + return nil +} + +func (p *GeographyType) Write(ctx context.Context, oprot thrift.TProtocol) error { + if err := oprot.WriteStructBegin(ctx, "GeographyType"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if p != nil { + if err := p.writeField1(ctx, oprot); err != nil { return err } + if err := p.writeField2(ctx, oprot); err != nil { return err } } if err := oprot.WriteFieldStop(ctx); err != nil { return thrift.PrependError("write field stop error: ", err) @@ -3474,36 +4729,78 @@ func (p *BsonType) Write(ctx context.Context, oprot thrift.TProtocol) error { return nil } -func (p *BsonType) Equals(other *BsonType) bool { +func (p *GeographyType) writeField1(ctx context.Context, oprot thrift.TProtocol) (err error) { + if p.IsSetCrs() { + if err := oprot.WriteFieldBegin(ctx, "crs", thrift.STRING, 1); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:crs: ", p), err) + } + if err := oprot.WriteString(ctx, string(*p.Crs)); err != nil { + return thrift.PrependError(fmt.Sprintf("%T.crs (1) field write error: ", p), err) + } + if err := oprot.WriteFieldEnd(ctx); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field end error 1:crs: ", p), err) + } + } + return err +} + +func (p *GeographyType) writeField2(ctx context.Context, oprot thrift.TProtocol) (err error) { + if p.IsSetAlgorithm() { + if err := oprot.WriteFieldBegin(ctx, "algorithm", thrift.I32, 2); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field begin error 2:algorithm: ", p), err) + } + if err := oprot.WriteI32(ctx, int32(*p.Algorithm)); err != nil { + return thrift.PrependError(fmt.Sprintf("%T.algorithm (2) field write error: ", p), err) + } + if err := oprot.WriteFieldEnd(ctx); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field end error 2:algorithm: ", p), err) + } + } + return err +} + +func (p *GeographyType) Equals(other *GeographyType) bool { if p == other { return true } else if p == nil || other == nil { return false } + if p.Crs != other.Crs { + if p.Crs == nil || other.Crs == nil { + return false + } + if (*p.Crs) != (*other.Crs) { return false } + } + if p.Algorithm != other.Algorithm { + if p.Algorithm == nil || other.Algorithm == nil { + return false + } + if (*p.Algorithm) != (*other.Algorithm) { return false } + } return true } -func (p *BsonType) String() string { +func (p *GeographyType) String() string { if p == nil { return "" } - return fmt.Sprintf("BsonType(%+v)", *p) + return fmt.Sprintf("GeographyType(%+v)", *p) } -func (p *BsonType) LogValue() slog.Value { +func (p *GeographyType) LogValue() slog.Value { if p == nil { return slog.AnyValue(nil) } v := thrift.SlogTStructWrapper{ - Type: "*parquet.BsonType", + Type: "*parquet.GeographyType", Value: p, } return slog.AnyValue(v) } -var _ slog.LogValuer = (*BsonType)(nil) +var _ slog.LogValuer = 
(*GeographyType)(nil) -func (p *BsonType) Validate() error { +func (p *GeographyType) Validate() error { return nil } @@ -3528,6 +4825,9 @@ func (p *BsonType) Validate() error { // - BSON // - UUID // - FLOAT16 +// - VARIANT +// - GEOMETRY +// - GEOGRAPHY // type LogicalType struct { STRING *StringType `thrift:"STRING,1" db:"STRING" json:"STRING,omitempty"` @@ -3545,6 +4845,9 @@ type LogicalType struct { BSON *BsonType `thrift:"BSON,13" db:"BSON" json:"BSON,omitempty"` UUID *UUIDType `thrift:"UUID,14" db:"UUID" json:"UUID,omitempty"` FLOAT16 *Float16Type `thrift:"FLOAT16,15" db:"FLOAT16" json:"FLOAT16,omitempty"` + VARIANT *VariantType `thrift:"VARIANT,16" db:"VARIANT" json:"VARIANT,omitempty"` + GEOMETRY *GeometryType `thrift:"GEOMETRY,17" db:"GEOMETRY" json:"GEOMETRY,omitempty"` + GEOGRAPHY *GeographyType `thrift:"GEOGRAPHY,18" db:"GEOGRAPHY" json:"GEOGRAPHY,omitempty"` } func NewLogicalType() *LogicalType { @@ -3677,6 +4980,33 @@ func (p *LogicalType) GetFLOAT16() *Float16Type { return p.FLOAT16 } +var LogicalType_VARIANT_DEFAULT *VariantType + +func (p *LogicalType) GetVARIANT() *VariantType { + if !p.IsSetVARIANT() { + return LogicalType_VARIANT_DEFAULT + } + return p.VARIANT +} + +var LogicalType_GEOMETRY_DEFAULT *GeometryType + +func (p *LogicalType) GetGEOMETRY() *GeometryType { + if !p.IsSetGEOMETRY() { + return LogicalType_GEOMETRY_DEFAULT + } + return p.GEOMETRY +} + +var LogicalType_GEOGRAPHY_DEFAULT *GeographyType + +func (p *LogicalType) GetGEOGRAPHY() *GeographyType { + if !p.IsSetGEOGRAPHY() { + return LogicalType_GEOGRAPHY_DEFAULT + } + return p.GEOGRAPHY +} + func (p *LogicalType) CountSetFieldsLogicalType() int { count := 0 if (p.IsSetSTRING()) { @@ -3721,6 +5051,15 @@ func (p *LogicalType) CountSetFieldsLogicalType() int { if (p.IsSetFLOAT16()) { count++ } + if (p.IsSetVARIANT()) { + count++ + } + if (p.IsSetGEOMETRY()) { + count++ + } + if (p.IsSetGEOGRAPHY()) { + count++ + } return count } @@ -3781,6 +5120,18 @@ func (p *LogicalType) IsSetFLOAT16() bool { return p.FLOAT16 != nil } +func (p *LogicalType) IsSetVARIANT() bool { + return p.VARIANT != nil +} + +func (p *LogicalType) IsSetGEOMETRY() bool { + return p.GEOMETRY != nil +} + +func (p *LogicalType) IsSetGEOGRAPHY() bool { + return p.GEOGRAPHY != nil +} + func (p *LogicalType) Read(ctx context.Context, iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(ctx); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) @@ -3936,6 +5287,36 @@ func (p *LogicalType) Read(ctx context.Context, iprot thrift.TProtocol) error { return err } } + case 16: + if fieldTypeId == thrift.STRUCT { + if err := p.ReadField16(ctx, iprot); err != nil { + return err + } + } else { + if err := iprot.Skip(ctx, fieldTypeId); err != nil { + return err + } + } + case 17: + if fieldTypeId == thrift.STRUCT { + if err := p.ReadField17(ctx, iprot); err != nil { + return err + } + } else { + if err := iprot.Skip(ctx, fieldTypeId); err != nil { + return err + } + } + case 18: + if fieldTypeId == thrift.STRUCT { + if err := p.ReadField18(ctx, iprot); err != nil { + return err + } + } else { + if err := iprot.Skip(ctx, fieldTypeId); err != nil { + return err + } + } default: if err := iprot.Skip(ctx, fieldTypeId); err != nil { return err @@ -4063,6 +5444,30 @@ func (p *LogicalType) ReadField15(ctx context.Context, iprot thrift.TProtocol) e return nil } +func (p *LogicalType) ReadField16(ctx context.Context, iprot thrift.TProtocol) error { + p.VARIANT = &VariantType{} + if err := p.VARIANT.Read(ctx, iprot); err != 
nil { + return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", p.VARIANT), err) + } + return nil +} + +func (p *LogicalType) ReadField17(ctx context.Context, iprot thrift.TProtocol) error { + p.GEOMETRY = &GeometryType{} + if err := p.GEOMETRY.Read(ctx, iprot); err != nil { + return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", p.GEOMETRY), err) + } + return nil +} + +func (p *LogicalType) ReadField18(ctx context.Context, iprot thrift.TProtocol) error { + p.GEOGRAPHY = &GeographyType{} + if err := p.GEOGRAPHY.Read(ctx, iprot); err != nil { + return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", p.GEOGRAPHY), err) + } + return nil +} + func (p *LogicalType) Write(ctx context.Context, oprot thrift.TProtocol) error { if c := p.CountSetFieldsLogicalType(); c != 1 { return fmt.Errorf("%T write union: exactly one field must be set (%d set)", p, c) @@ -4085,6 +5490,9 @@ func (p *LogicalType) Write(ctx context.Context, oprot thrift.TProtocol) error { if err := p.writeField13(ctx, oprot); err != nil { return err } if err := p.writeField14(ctx, oprot); err != nil { return err } if err := p.writeField15(ctx, oprot); err != nil { return err } + if err := p.writeField16(ctx, oprot); err != nil { return err } + if err := p.writeField17(ctx, oprot); err != nil { return err } + if err := p.writeField18(ctx, oprot); err != nil { return err } } if err := oprot.WriteFieldStop(ctx); err != nil { return thrift.PrependError("write field stop error: ", err) @@ -4305,6 +5713,51 @@ func (p *LogicalType) writeField15(ctx context.Context, oprot thrift.TProtocol) return err } +func (p *LogicalType) writeField16(ctx context.Context, oprot thrift.TProtocol) (err error) { + if p.IsSetVARIANT() { + if err := oprot.WriteFieldBegin(ctx, "VARIANT", thrift.STRUCT, 16); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field begin error 16:VARIANT: ", p), err) + } + if err := p.VARIANT.Write(ctx, oprot); err != nil { + return thrift.PrependError(fmt.Sprintf("%T error writing struct: ", p.VARIANT), err) + } + if err := oprot.WriteFieldEnd(ctx); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field end error 16:VARIANT: ", p), err) + } + } + return err +} + +func (p *LogicalType) writeField17(ctx context.Context, oprot thrift.TProtocol) (err error) { + if p.IsSetGEOMETRY() { + if err := oprot.WriteFieldBegin(ctx, "GEOMETRY", thrift.STRUCT, 17); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field begin error 17:GEOMETRY: ", p), err) + } + if err := p.GEOMETRY.Write(ctx, oprot); err != nil { + return thrift.PrependError(fmt.Sprintf("%T error writing struct: ", p.GEOMETRY), err) + } + if err := oprot.WriteFieldEnd(ctx); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field end error 17:GEOMETRY: ", p), err) + } + } + return err +} + +func (p *LogicalType) writeField18(ctx context.Context, oprot thrift.TProtocol) (err error) { + if p.IsSetGEOGRAPHY() { + if err := oprot.WriteFieldBegin(ctx, "GEOGRAPHY", thrift.STRUCT, 18); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field begin error 18:GEOGRAPHY: ", p), err) + } + if err := p.GEOGRAPHY.Write(ctx, oprot); err != nil { + return thrift.PrependError(fmt.Sprintf("%T error writing struct: ", p.GEOGRAPHY), err) + } + if err := oprot.WriteFieldEnd(ctx); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field end error 18:GEOGRAPHY: ", p), err) + } + } + return err +} + func (p *LogicalType) Equals(other *LogicalType) bool { if p == other { return true 
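Editorial sketch (not part of the generated file): LogicalType remains a union whose Write enforces that exactly one member is set (CountSetFieldsLogicalType() must equal 1); this change only adds VARIANT, GEOMETRY and GEOGRAPHY as members 16-18. A minimal example of tagging a column with the new Geometry annotation, assuming the generated package is imported as parquet:

	geom := parquet.NewGeometryType()
	crs := "OGC:CRS84" // optional; also the documented default when crs is left unset
	geom.Crs = &crs
	lt := parquet.NewLogicalType()
	lt.GEOMETRY = geom // union: exactly one member may be set
	// lt.Write(ctx, oprot) then emits this as STRUCT field 17 (GEOMETRY); setting a
	// second member as well would trip the "exactly one field must be set" error.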
@@ -4325,6 +5778,9 @@ func (p *LogicalType) Equals(other *LogicalType) bool { if !p.BSON.Equals(other.BSON) { return false } if !p.UUID.Equals(other.UUID) { return false } if !p.FLOAT16.Equals(other.FLOAT16) { return false } + if !p.VARIANT.Equals(other.VARIANT) { return false } + if !p.GEOMETRY.Equals(other.GEOMETRY) { return false } + if !p.GEOGRAPHY.Equals(other.GEOGRAPHY) { return false } return true } @@ -8342,6 +9798,7 @@ func (p *PageEncodingStats) Validate() error { // representations. The histograms contained in these statistics can // also be useful in some cases for more fine-grained nullability/list length // filter pushdown. +// - GeospatialStatistics: Optional statistics specific for Geometry and Geography logical types // type ColumnMetaData struct { Type Type `thrift:"type,1,required" db:"type" json:"type"` @@ -8360,6 +9817,7 @@ type ColumnMetaData struct { BloomFilterOffset *int64 `thrift:"bloom_filter_offset,14" db:"bloom_filter_offset" json:"bloom_filter_offset,omitempty"` BloomFilterLength *int32 `thrift:"bloom_filter_length,15" db:"bloom_filter_length" json:"bloom_filter_length,omitempty"` SizeStatistics *SizeStatistics `thrift:"size_statistics,16" db:"size_statistics" json:"size_statistics,omitempty"` + GeospatialStatistics *GeospatialStatistics `thrift:"geospatial_statistics,17" db:"geospatial_statistics" json:"geospatial_statistics,omitempty"` } func NewColumnMetaData() *ColumnMetaData { @@ -8482,6 +9940,15 @@ func (p *ColumnMetaData) GetSizeStatistics() *SizeStatistics { return p.SizeStatistics } +var ColumnMetaData_GeospatialStatistics_DEFAULT *GeospatialStatistics + +func (p *ColumnMetaData) GetGeospatialStatistics() *GeospatialStatistics { + if !p.IsSetGeospatialStatistics() { + return ColumnMetaData_GeospatialStatistics_DEFAULT + } + return p.GeospatialStatistics +} + func (p *ColumnMetaData) IsSetKeyValueMetadata() bool { return p.KeyValueMetadata != nil } @@ -8514,6 +9981,10 @@ func (p *ColumnMetaData) IsSetSizeStatistics() bool { return p.SizeStatistics != nil } +func (p *ColumnMetaData) IsSetGeospatialStatistics() bool { + return p.GeospatialStatistics != nil +} + func (p *ColumnMetaData) Read(ctx context.Context, iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(ctx); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) @@ -8705,6 +10176,16 @@ func (p *ColumnMetaData) Read(ctx context.Context, iprot thrift.TProtocol) error return err } } + case 17: + if fieldTypeId == thrift.STRUCT { + if err := p.ReadField17(ctx, iprot); err != nil { + return err + } + } else { + if err := iprot.Skip(ctx, fieldTypeId); err != nil { + return err + } + } default: if err := iprot.Skip(ctx, fieldTypeId); err != nil { return err @@ -8762,14 +10243,14 @@ func (p *ColumnMetaData) ReadField2(ctx context.Context, iprot thrift.TProtocol) tSlice := make([]Encoding, 0, size) p.Encodings = tSlice for i := 0; i < size; i++ { - var _elem4 Encoding + var _elem6 Encoding if v, err := iprot.ReadI32(ctx); err != nil { return thrift.PrependError("error reading field 0: ", err) } else { temp := Encoding(v) - _elem4 = temp + _elem6 = temp } - p.Encodings = append(p.Encodings, _elem4) + p.Encodings = append(p.Encodings, _elem6) } if err := iprot.ReadListEnd(ctx); err != nil { return thrift.PrependError("error reading list end: ", err) @@ -8785,13 +10266,13 @@ func (p *ColumnMetaData) ReadField3(ctx context.Context, iprot thrift.TProtocol) tSlice := make([]string, 0, size) p.PathInSchema = tSlice for i := 0; i < size; i++ { - var _elem5 string + var 
_elem7 string if v, err := iprot.ReadString(ctx); err != nil { return thrift.PrependError("error reading field 0: ", err) } else { - _elem5 = v + _elem7 = v } - p.PathInSchema = append(p.PathInSchema, _elem5) + p.PathInSchema = append(p.PathInSchema, _elem7) } if err := iprot.ReadListEnd(ctx); err != nil { return thrift.PrependError("error reading list end: ", err) @@ -8844,11 +10325,11 @@ func (p *ColumnMetaData) ReadField8(ctx context.Context, iprot thrift.TProtocol) tSlice := make([]*KeyValue, 0, size) p.KeyValueMetadata = tSlice for i := 0; i < size; i++ { - _elem6 := &KeyValue{} - if err := _elem6.Read(ctx, iprot); err != nil { - return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", _elem6), err) + _elem8 := &KeyValue{} + if err := _elem8.Read(ctx, iprot); err != nil { + return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", _elem8), err) } - p.KeyValueMetadata = append(p.KeyValueMetadata, _elem6) + p.KeyValueMetadata = append(p.KeyValueMetadata, _elem8) } if err := iprot.ReadListEnd(ctx); err != nil { return thrift.PrependError("error reading list end: ", err) @@ -8899,11 +10380,11 @@ func (p *ColumnMetaData) ReadField13(ctx context.Context, iprot thrift.TProtocol tSlice := make([]*PageEncodingStats, 0, size) p.EncodingStats = tSlice for i := 0; i < size; i++ { - _elem7 := &PageEncodingStats{} - if err := _elem7.Read(ctx, iprot); err != nil { - return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", _elem7), err) + _elem9 := &PageEncodingStats{} + if err := _elem9.Read(ctx, iprot); err != nil { + return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", _elem9), err) } - p.EncodingStats = append(p.EncodingStats, _elem7) + p.EncodingStats = append(p.EncodingStats, _elem9) } if err := iprot.ReadListEnd(ctx); err != nil { return thrift.PrependError("error reading list end: ", err) @@ -8937,6 +10418,14 @@ func (p *ColumnMetaData) ReadField16(ctx context.Context, iprot thrift.TProtocol return nil } +func (p *ColumnMetaData) ReadField17(ctx context.Context, iprot thrift.TProtocol) error { + p.GeospatialStatistics = &GeospatialStatistics{} + if err := p.GeospatialStatistics.Read(ctx, iprot); err != nil { + return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", p.GeospatialStatistics), err) + } + return nil +} + func (p *ColumnMetaData) Write(ctx context.Context, oprot thrift.TProtocol) error { if err := oprot.WriteStructBegin(ctx, "ColumnMetaData"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) @@ -8958,6 +10447,7 @@ func (p *ColumnMetaData) Write(ctx context.Context, oprot thrift.TProtocol) erro if err := p.writeField14(ctx, oprot); err != nil { return err } if err := p.writeField15(ctx, oprot); err != nil { return err } if err := p.writeField16(ctx, oprot); err != nil { return err } + if err := p.writeField17(ctx, oprot); err != nil { return err } } if err := oprot.WriteFieldStop(ctx); err != nil { return thrift.PrependError("write field stop error: ", err) @@ -9224,6 +10714,21 @@ func (p *ColumnMetaData) writeField16(ctx context.Context, oprot thrift.TProtoco return err } +func (p *ColumnMetaData) writeField17(ctx context.Context, oprot thrift.TProtocol) (err error) { + if p.IsSetGeospatialStatistics() { + if err := oprot.WriteFieldBegin(ctx, "geospatial_statistics", thrift.STRUCT, 17); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field begin error 17:geospatial_statistics: ", p), err) + } + if err := p.GeospatialStatistics.Write(ctx, oprot); err != nil { + 
return thrift.PrependError(fmt.Sprintf("%T error writing struct: ", p.GeospatialStatistics), err) + } + if err := oprot.WriteFieldEnd(ctx); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field end error 17:geospatial_statistics: ", p), err) + } + } + return err +} + func (p *ColumnMetaData) Equals(other *ColumnMetaData) bool { if p == other { return true @@ -9233,13 +10738,13 @@ func (p *ColumnMetaData) Equals(other *ColumnMetaData) bool { if p.Type != other.Type { return false } if len(p.Encodings) != len(other.Encodings) { return false } for i, _tgt := range p.Encodings { - _src8 := other.Encodings[i] - if _tgt != _src8 { return false } + _src10 := other.Encodings[i] + if _tgt != _src10 { return false } } if len(p.PathInSchema) != len(other.PathInSchema) { return false } for i, _tgt := range p.PathInSchema { - _src9 := other.PathInSchema[i] - if _tgt != _src9 { return false } + _src11 := other.PathInSchema[i] + if _tgt != _src11 { return false } } if p.Codec != other.Codec { return false } if p.NumValues != other.NumValues { return false } @@ -9247,8 +10752,8 @@ func (p *ColumnMetaData) Equals(other *ColumnMetaData) bool { if p.TotalCompressedSize != other.TotalCompressedSize { return false } if len(p.KeyValueMetadata) != len(other.KeyValueMetadata) { return false } for i, _tgt := range p.KeyValueMetadata { - _src10 := other.KeyValueMetadata[i] - if !_tgt.Equals(_src10) { return false } + _src12 := other.KeyValueMetadata[i] + if !_tgt.Equals(_src12) { return false } } if p.DataPageOffset != other.DataPageOffset { return false } if p.IndexPageOffset != other.IndexPageOffset { @@ -9266,8 +10771,8 @@ func (p *ColumnMetaData) Equals(other *ColumnMetaData) bool { if !p.Statistics.Equals(other.Statistics) { return false } if len(p.EncodingStats) != len(other.EncodingStats) { return false } for i, _tgt := range p.EncodingStats { - _src11 := other.EncodingStats[i] - if !_tgt.Equals(_src11) { return false } + _src13 := other.EncodingStats[i] + if !_tgt.Equals(_src13) { return false } } if p.BloomFilterOffset != other.BloomFilterOffset { if p.BloomFilterOffset == nil || other.BloomFilterOffset == nil { @@ -9282,6 +10787,7 @@ func (p *ColumnMetaData) Equals(other *ColumnMetaData) bool { if (*p.BloomFilterLength) != (*other.BloomFilterLength) { return false } } if !p.SizeStatistics.Equals(other.SizeStatistics) { return false } + if !p.GeospatialStatistics.Equals(other.GeospatialStatistics) { return false } return true } @@ -9484,13 +10990,13 @@ func (p *EncryptionWithColumnKey) ReadField1(ctx context.Context, iprot thrift.T tSlice := make([]string, 0, size) p.PathInSchema = tSlice for i := 0; i < size; i++ { - var _elem12 string + var _elem14 string if v, err := iprot.ReadString(ctx); err != nil { return thrift.PrependError("error reading field 0: ", err) } else { - _elem12 = v + _elem14 = v } - p.PathInSchema = append(p.PathInSchema, _elem12) + p.PathInSchema = append(p.PathInSchema, _elem14) } if err := iprot.ReadListEnd(ctx); err != nil { return thrift.PrependError("error reading list end: ", err) @@ -9568,8 +11074,8 @@ func (p *EncryptionWithColumnKey) Equals(other *EncryptionWithColumnKey) bool { } if len(p.PathInSchema) != len(other.PathInSchema) { return false } for i, _tgt := range p.PathInSchema { - _src13 := other.PathInSchema[i] - if _tgt != _src13 { return false } + _src15 := other.PathInSchema[i] + if _tgt != _src15 { return false } } if bytes.Compare(p.KeyMetadata, other.KeyMetadata) != 0 { return false } return true @@ -10596,11 +12102,11 @@ func (p *RowGroup) 
ReadField1(ctx context.Context, iprot thrift.TProtocol) error tSlice := make([]*ColumnChunk, 0, size) p.Columns = tSlice for i := 0; i < size; i++ { - _elem14 := &ColumnChunk{} - if err := _elem14.Read(ctx, iprot); err != nil { - return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", _elem14), err) + _elem16 := &ColumnChunk{} + if err := _elem16.Read(ctx, iprot); err != nil { + return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", _elem16), err) } - p.Columns = append(p.Columns, _elem14) + p.Columns = append(p.Columns, _elem16) } if err := iprot.ReadListEnd(ctx); err != nil { return thrift.PrependError("error reading list end: ", err) @@ -10634,11 +12140,11 @@ func (p *RowGroup) ReadField4(ctx context.Context, iprot thrift.TProtocol) error tSlice := make([]*SortingColumn, 0, size) p.SortingColumns = tSlice for i := 0; i < size; i++ { - _elem15 := &SortingColumn{} - if err := _elem15.Read(ctx, iprot); err != nil { - return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", _elem15), err) + _elem17 := &SortingColumn{} + if err := _elem17.Read(ctx, iprot); err != nil { + return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", _elem17), err) } - p.SortingColumns = append(p.SortingColumns, _elem15) + p.SortingColumns = append(p.SortingColumns, _elem17) } if err := iprot.ReadListEnd(ctx); err != nil { return thrift.PrependError("error reading list end: ", err) @@ -10818,15 +12324,15 @@ func (p *RowGroup) Equals(other *RowGroup) bool { } if len(p.Columns) != len(other.Columns) { return false } for i, _tgt := range p.Columns { - _src16 := other.Columns[i] - if !_tgt.Equals(_src16) { return false } + _src18 := other.Columns[i] + if !_tgt.Equals(_src18) { return false } } if p.TotalByteSize != other.TotalByteSize { return false } if p.NumRows != other.NumRows { return false } if len(p.SortingColumns) != len(other.SortingColumns) { return false } for i, _tgt := range p.SortingColumns { - _src17 := other.SortingColumns[i] - if !_tgt.Equals(_src17) { return false } + _src19 := other.SortingColumns[i] + if !_tgt.Equals(_src19) { return false } } if p.FileOffset != other.FileOffset { if p.FileOffset == nil || other.FileOffset == nil { @@ -10990,6 +12496,9 @@ func (p *TypeDefinedOrder) Validate() error { // ENUM - unsigned byte-wise comparison // LIST - undefined // MAP - undefined +// VARIANT - undefined +// GEOMETRY - undefined +// GEOGRAPHY - undefined // // In the absence of logical types, the sort order is determined by the physical type: // BOOLEAN - false, true @@ -11497,11 +13006,11 @@ func (p *OffsetIndex) ReadField1(ctx context.Context, iprot thrift.TProtocol) er tSlice := make([]*PageLocation, 0, size) p.PageLocations = tSlice for i := 0; i < size; i++ { - _elem18 := &PageLocation{} - if err := _elem18.Read(ctx, iprot); err != nil { - return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", _elem18), err) + _elem20 := &PageLocation{} + if err := _elem20.Read(ctx, iprot); err != nil { + return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", _elem20), err) } - p.PageLocations = append(p.PageLocations, _elem18) + p.PageLocations = append(p.PageLocations, _elem20) } if err := iprot.ReadListEnd(ctx); err != nil { return thrift.PrependError("error reading list end: ", err) @@ -11517,13 +13026,13 @@ func (p *OffsetIndex) ReadField2(ctx context.Context, iprot thrift.TProtocol) er tSlice := make([]int64, 0, size) p.UnencodedByteArrayDataBytes = tSlice for i := 0; i < size; i++ { - var _elem19 int64 + var _elem21 int64 if v, err := 
iprot.ReadI64(ctx); err != nil { return thrift.PrependError("error reading field 0: ", err) } else { - _elem19 = v + _elem21 = v } - p.UnencodedByteArrayDataBytes = append(p.UnencodedByteArrayDataBytes, _elem19) + p.UnencodedByteArrayDataBytes = append(p.UnencodedByteArrayDataBytes, _elem21) } if err := iprot.ReadListEnd(ctx); err != nil { return thrift.PrependError("error reading list end: ", err) @@ -11600,13 +13109,13 @@ func (p *OffsetIndex) Equals(other *OffsetIndex) bool { } if len(p.PageLocations) != len(other.PageLocations) { return false } for i, _tgt := range p.PageLocations { - _src20 := other.PageLocations[i] - if !_tgt.Equals(_src20) { return false } + _src22 := other.PageLocations[i] + if !_tgt.Equals(_src22) { return false } } if len(p.UnencodedByteArrayDataBytes) != len(other.UnencodedByteArrayDataBytes) { return false } for i, _tgt := range p.UnencodedByteArrayDataBytes { - _src21 := other.UnencodedByteArrayDataBytes[i] - if _tgt != _src21 { return false } + _src23 := other.UnencodedByteArrayDataBytes[i] + if _tgt != _src23 { return false } } return true } @@ -11884,13 +13393,13 @@ func (p *ColumnIndex) ReadField1(ctx context.Context, iprot thrift.TProtocol) er tSlice := make([]bool, 0, size) p.NullPages = tSlice for i := 0; i < size; i++ { - var _elem22 bool + var _elem24 bool if v, err := iprot.ReadBool(ctx); err != nil { return thrift.PrependError("error reading field 0: ", err) } else { - _elem22 = v + _elem24 = v } - p.NullPages = append(p.NullPages, _elem22) + p.NullPages = append(p.NullPages, _elem24) } if err := iprot.ReadListEnd(ctx); err != nil { return thrift.PrependError("error reading list end: ", err) @@ -11906,13 +13415,13 @@ func (p *ColumnIndex) ReadField2(ctx context.Context, iprot thrift.TProtocol) er tSlice := make([][]byte, 0, size) p.MinValues = tSlice for i := 0; i < size; i++ { - var _elem23 []byte + var _elem25 []byte if v, err := iprot.ReadBinary(ctx); err != nil { return thrift.PrependError("error reading field 0: ", err) } else { - _elem23 = v + _elem25 = v } - p.MinValues = append(p.MinValues, _elem23) + p.MinValues = append(p.MinValues, _elem25) } if err := iprot.ReadListEnd(ctx); err != nil { return thrift.PrependError("error reading list end: ", err) @@ -11928,13 +13437,13 @@ func (p *ColumnIndex) ReadField3(ctx context.Context, iprot thrift.TProtocol) er tSlice := make([][]byte, 0, size) p.MaxValues = tSlice for i := 0; i < size; i++ { - var _elem24 []byte + var _elem26 []byte if v, err := iprot.ReadBinary(ctx); err != nil { return thrift.PrependError("error reading field 0: ", err) } else { - _elem24 = v + _elem26 = v } - p.MaxValues = append(p.MaxValues, _elem24) + p.MaxValues = append(p.MaxValues, _elem26) } if err := iprot.ReadListEnd(ctx); err != nil { return thrift.PrependError("error reading list end: ", err) @@ -11960,13 +13469,13 @@ func (p *ColumnIndex) ReadField5(ctx context.Context, iprot thrift.TProtocol) er tSlice := make([]int64, 0, size) p.NullCounts = tSlice for i := 0; i < size; i++ { - var _elem25 int64 + var _elem27 int64 if v, err := iprot.ReadI64(ctx); err != nil { return thrift.PrependError("error reading field 0: ", err) } else { - _elem25 = v + _elem27 = v } - p.NullCounts = append(p.NullCounts, _elem25) + p.NullCounts = append(p.NullCounts, _elem27) } if err := iprot.ReadListEnd(ctx); err != nil { return thrift.PrependError("error reading list end: ", err) @@ -11982,13 +13491,13 @@ func (p *ColumnIndex) ReadField6(ctx context.Context, iprot thrift.TProtocol) er tSlice := make([]int64, 0, size) 
p.RepetitionLevelHistograms = tSlice for i := 0; i < size; i++ { - var _elem26 int64 + var _elem28 int64 if v, err := iprot.ReadI64(ctx); err != nil { return thrift.PrependError("error reading field 0: ", err) } else { - _elem26 = v + _elem28 = v } - p.RepetitionLevelHistograms = append(p.RepetitionLevelHistograms, _elem26) + p.RepetitionLevelHistograms = append(p.RepetitionLevelHistograms, _elem28) } if err := iprot.ReadListEnd(ctx); err != nil { return thrift.PrependError("error reading list end: ", err) @@ -12004,13 +13513,13 @@ func (p *ColumnIndex) ReadField7(ctx context.Context, iprot thrift.TProtocol) er tSlice := make([]int64, 0, size) p.DefinitionLevelHistograms = tSlice for i := 0; i < size; i++ { - var _elem27 int64 + var _elem29 int64 if v, err := iprot.ReadI64(ctx); err != nil { return thrift.PrependError("error reading field 0: ", err) } else { - _elem27 = v + _elem29 = v } - p.DefinitionLevelHistograms = append(p.DefinitionLevelHistograms, _elem27) + p.DefinitionLevelHistograms = append(p.DefinitionLevelHistograms, _elem29) } if err := iprot.ReadListEnd(ctx); err != nil { return thrift.PrependError("error reading list end: ", err) @@ -12193,34 +13702,34 @@ func (p *ColumnIndex) Equals(other *ColumnIndex) bool { } if len(p.NullPages) != len(other.NullPages) { return false } for i, _tgt := range p.NullPages { - _src28 := other.NullPages[i] - if _tgt != _src28 { return false } + _src30 := other.NullPages[i] + if _tgt != _src30 { return false } } if len(p.MinValues) != len(other.MinValues) { return false } for i, _tgt := range p.MinValues { - _src29 := other.MinValues[i] - if bytes.Compare(_tgt, _src29) != 0 { return false } + _src31 := other.MinValues[i] + if bytes.Compare(_tgt, _src31) != 0 { return false } } if len(p.MaxValues) != len(other.MaxValues) { return false } for i, _tgt := range p.MaxValues { - _src30 := other.MaxValues[i] - if bytes.Compare(_tgt, _src30) != 0 { return false } + _src32 := other.MaxValues[i] + if bytes.Compare(_tgt, _src32) != 0 { return false } } if p.BoundaryOrder != other.BoundaryOrder { return false } if len(p.NullCounts) != len(other.NullCounts) { return false } for i, _tgt := range p.NullCounts { - _src31 := other.NullCounts[i] - if _tgt != _src31 { return false } + _src33 := other.NullCounts[i] + if _tgt != _src33 { return false } } if len(p.RepetitionLevelHistograms) != len(other.RepetitionLevelHistograms) { return false } for i, _tgt := range p.RepetitionLevelHistograms { - _src32 := other.RepetitionLevelHistograms[i] - if _tgt != _src32 { return false } + _src34 := other.RepetitionLevelHistograms[i] + if _tgt != _src34 { return false } } if len(p.DefinitionLevelHistograms) != len(other.DefinitionLevelHistograms) { return false } for i, _tgt := range p.DefinitionLevelHistograms { - _src33 := other.DefinitionLevelHistograms[i] - if _tgt != _src33 { return false } + _src35 := other.DefinitionLevelHistograms[i] + if _tgt != _src35 { return false } } return true } @@ -13228,11 +14737,11 @@ func (p *FileMetaData) ReadField2(ctx context.Context, iprot thrift.TProtocol) e tSlice := make([]*SchemaElement, 0, size) p.Schema = tSlice for i := 0; i < size; i++ { - _elem34 := &SchemaElement{} - if err := _elem34.Read(ctx, iprot); err != nil { - return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", _elem34), err) + _elem36 := &SchemaElement{} + if err := _elem36.Read(ctx, iprot); err != nil { + return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", _elem36), err) } - p.Schema = append(p.Schema, _elem34) + p.Schema = 
append(p.Schema, _elem36) } if err := iprot.ReadListEnd(ctx); err != nil { return thrift.PrependError("error reading list end: ", err) @@ -13257,11 +14766,11 @@ func (p *FileMetaData) ReadField4(ctx context.Context, iprot thrift.TProtocol) e tSlice := make([]*RowGroup, 0, size) p.RowGroups = tSlice for i := 0; i < size; i++ { - _elem35 := &RowGroup{} - if err := _elem35.Read(ctx, iprot); err != nil { - return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", _elem35), err) + _elem37 := &RowGroup{} + if err := _elem37.Read(ctx, iprot); err != nil { + return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", _elem37), err) } - p.RowGroups = append(p.RowGroups, _elem35) + p.RowGroups = append(p.RowGroups, _elem37) } if err := iprot.ReadListEnd(ctx); err != nil { return thrift.PrependError("error reading list end: ", err) @@ -13277,11 +14786,11 @@ func (p *FileMetaData) ReadField5(ctx context.Context, iprot thrift.TProtocol) e tSlice := make([]*KeyValue, 0, size) p.KeyValueMetadata = tSlice for i := 0; i < size; i++ { - _elem36 := &KeyValue{} - if err := _elem36.Read(ctx, iprot); err != nil { - return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", _elem36), err) + _elem38 := &KeyValue{} + if err := _elem38.Read(ctx, iprot); err != nil { + return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", _elem38), err) } - p.KeyValueMetadata = append(p.KeyValueMetadata, _elem36) + p.KeyValueMetadata = append(p.KeyValueMetadata, _elem38) } if err := iprot.ReadListEnd(ctx); err != nil { return thrift.PrependError("error reading list end: ", err) @@ -13306,11 +14815,11 @@ func (p *FileMetaData) ReadField7(ctx context.Context, iprot thrift.TProtocol) e tSlice := make([]*ColumnOrder, 0, size) p.ColumnOrders = tSlice for i := 0; i < size; i++ { - _elem37 := &ColumnOrder{} - if err := _elem37.Read(ctx, iprot); err != nil { - return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", _elem37), err) + _elem39 := &ColumnOrder{} + if err := _elem39.Read(ctx, iprot); err != nil { + return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", _elem39), err) } - p.ColumnOrders = append(p.ColumnOrders, _elem37) + p.ColumnOrders = append(p.ColumnOrders, _elem39) } if err := iprot.ReadListEnd(ctx); err != nil { return thrift.PrependError("error reading list end: ", err) @@ -13527,19 +15036,19 @@ func (p *FileMetaData) Equals(other *FileMetaData) bool { if p.Version != other.Version { return false } if len(p.Schema) != len(other.Schema) { return false } for i, _tgt := range p.Schema { - _src38 := other.Schema[i] - if !_tgt.Equals(_src38) { return false } + _src40 := other.Schema[i] + if !_tgt.Equals(_src40) { return false } } if p.NumRows != other.NumRows { return false } if len(p.RowGroups) != len(other.RowGroups) { return false } for i, _tgt := range p.RowGroups { - _src39 := other.RowGroups[i] - if !_tgt.Equals(_src39) { return false } + _src41 := other.RowGroups[i] + if !_tgt.Equals(_src41) { return false } } if len(p.KeyValueMetadata) != len(other.KeyValueMetadata) { return false } for i, _tgt := range p.KeyValueMetadata { - _src40 := other.KeyValueMetadata[i] - if !_tgt.Equals(_src40) { return false } + _src42 := other.KeyValueMetadata[i] + if !_tgt.Equals(_src42) { return false } } if p.CreatedBy != other.CreatedBy { if p.CreatedBy == nil || other.CreatedBy == nil { @@ -13549,8 +15058,8 @@ func (p *FileMetaData) Equals(other *FileMetaData) bool { } if len(p.ColumnOrders) != len(other.ColumnOrders) { return false } for i, _tgt := range p.ColumnOrders { - 
_src41 := other.ColumnOrders[i] - if !_tgt.Equals(_src41) { return false } + _src43 := other.ColumnOrders[i] + if !_tgt.Equals(_src43) { return false } } if !p.EncryptionAlgorithm.Equals(other.EncryptionAlgorithm) { return false } if bytes.Compare(p.FooterSigningKeyMetadata, other.FooterSigningKeyMetadata) != 0 { return false } diff --git a/parquet/metadata/Makefile b/parquet/metadata/Makefile new file mode 100644 index 00000000..cfcc8ef1 --- /dev/null +++ b/parquet/metadata/Makefile @@ -0,0 +1,64 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# this converts rotate instructions from "ro[lr] " -> "ro[lr] , 1" for yasm compatibility +PERL_FIXUP_ROTATE=perl -i -pe 's/(ro[rl]\s+\w{2,3})$$/\1, 1/' + +C2GOASM=c2goasm +CC=clang-19 +C_FLAGS=-target x86_64-unknown-none -masm=intel -mno-red-zone -mstackrealign -mllvm -inline-threshold=1000 \ + -fno-asynchronous-unwind-tables -fno-exceptions -fno-rtti -O3 -fno-builtin -ffast-math -fno-jump-tables -I_lib +ASM_FLAGS_AVX2=-mavx2 -mfma +ASM_FLAGS_SSE4=-msse4 +ASM_FLAGS_BMI2=-mbmi2 +ASM_FLAGS_POPCNT=-mpopcnt + +C_FLAGS_NEON=-O3 -fvectorize -mllvm -force-vector-width=16 -fno-asynchronous-unwind-tables -mno-red-zone -mstackrealign -fno-exceptions \ + -fno-rtti -fno-builtin -ffast-math -fno-jump-tables -I_lib + +GO_SOURCES := $(shell find . -path ./_lib -prune -o -name '*.go' -not -name '*_test.go') +ALL_SOURCES := $(shell find . 
-path ./_lib -prune -o -name '*.go' -name '*.s' -not -name '*_test.go') + +.PHONY: assembly + +INTEL_SOURCES := \ + bloom_filter_block_avx2_amd64.s bloom_filter_block_sse4_amd64.s + +ARM_SOURCES := \ + bloom_filter_block_neon_arm64.s + + +assembly: $(INTEL_SOURCES) + +_lib/bloom_filter_block_avx2_amd64.s: _lib/bloom_filter_block.c + $(CC) -S $(C_FLAGS) $(ASM_FLAGS_AVX2) $^ -o $@ ; $(PERL_FIXUP_ROTATE) $@; perl -i -pe 's/mem(cpy|set)/clib·_mem\1(SB)/' $@ + +_lib/bloom_filter_block_sse4_amd64.s: _lib/bloom_filter_block.c + $(CC) -S $(C_FLAGS) $(ASM_FLAGS_SSE4) $^ -o $@ ; $(PERL_FIXUP_ROTATE) $@; perl -i -pe 's/mem(cpy|set)/clib·_mem\1(SB)/' $@ + +# neon not supported by c2goasm, will have to do it manually +#_lib/bloom_filter_block_neon.s: _lib/bloom_filter_block.c +# $(CC) -S $(C_FLAGS_NEON) $^ -o $@ ; $(PERL_FIXUP_ROTATE) $@ + +bloom_filter_block_avx2_amd64.s: _lib/bloom_filter_block_avx2_amd64.s + $(C2GOASM) -a -f $^ $@ + +bloom_filter_block_sse4_amd64.s: _lib/bloom_filter_block_sse4_amd64.s + $(C2GOASM) -a -f $^ $@ + +clean: + rm -f $(INTEL_SOURCES) + rm -f $(addprefix _lib/,$(INTEL_SOURCES)) \ No newline at end of file diff --git a/parquet/metadata/_lib/arch.h b/parquet/metadata/_lib/arch.h new file mode 100644 index 00000000..165d1202 --- /dev/null +++ b/parquet/metadata/_lib/arch.h @@ -0,0 +1,29 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#undef FULL_NAME + +#if defined(__AVX2__) + #define FULL_NAME(x) x##_avx2 +#elif __SSE4_2__ == 1 + #define FULL_NAME(x) x##_sse4 +#elif __SSE3__ == 1 + #define FULL_NAME(x) x##_sse3 +#elif defined(__ARM_NEON) || defined(__ARM_NEON__) + #define FULL_NAME(x) x##_neon +#else + #define FULL_NAME(x) x##_x86 +#endif \ No newline at end of file diff --git a/parquet/metadata/_lib/bloom_filter_block.c b/parquet/metadata/_lib/bloom_filter_block.c new file mode 100644 index 00000000..dced68be --- /dev/null +++ b/parquet/metadata/_lib/bloom_filter_block.c @@ -0,0 +1,64 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "arch.h" +#include <stdbool.h> +#include <stdint.h> + +// algorithms defined in https://github.com/apache/parquet-format/blob/master/BloomFilter.md +// describing the proper definitions for the bloom filter hash functions +// to be compatible with the parquet format + +#define bitsSetPerBlock 8 +static const uint32_t SALT[bitsSetPerBlock] = { + 0x47b6137bU, 0x44974d91U, 0x8824ad5bU, 0xa2b7289dU, + 0x705495c7U, 0x2df1424bU, 0x9efc4947U, 0x5c6bfb31U}; + +#define PREDICT_FALSE(x) (__builtin_expect(!!(x), 0)) + +bool FULL_NAME(check_block)(const uint32_t blocks[], const int len, const uint64_t hash){ + const uint32_t bucket_index = + (uint32_t)(((hash >> 32) * (uint64_t)(len/8)) >> 32); + const uint32_t key = (uint32_t)hash; + + for (int i = 0; i < bitsSetPerBlock; ++i) + { + const uint32_t mask = UINT32_C(0x1) << ((key * SALT[i]) >> 27); + if (PREDICT_FALSE(0 == (blocks[bitsSetPerBlock * bucket_index + i] & mask))) + { + return false; + } + } + return true; +} + +void FULL_NAME(insert_block)(uint32_t blocks[], const int len, const uint64_t hash) { + const uint32_t bucket_index = + (uint32_t)(((hash >> 32) * (uint64_t)(len/8)) >> 32); + const uint32_t key = (uint32_t)hash; + + for (int i = 0; i < bitsSetPerBlock; ++i) + { + const uint32_t mask = UINT32_C(0x1) << ((key * SALT[i]) >> 27); + blocks[bitsSetPerBlock * bucket_index + i] |= mask; + } +} + +void FULL_NAME(insert_bulk)(uint32_t blocks[], const int block_len, const uint64_t hashes[], const int num_hashes) { + for (int i = 0; i < num_hashes; ++i) { + FULL_NAME(insert_block)(blocks, block_len, hashes[i]); + } +} \ No newline at end of file diff --git a/parquet/metadata/_lib/bloom_filter_block_avx2_amd64.s b/parquet/metadata/_lib/bloom_filter_block_avx2_amd64.s new file mode 100644 index 00000000..4bb08777 --- /dev/null +++ b/parquet/metadata/_lib/bloom_filter_block_avx2_amd64.s @@ -0,0 +1,295 @@ + .text + .intel_syntax noprefix + .file "bloom_filter_block.c" + .globl check_block_avx2 # -- Begin function check_block_avx2 + .p2align 4, 0x90 + .type check_block_avx2,@function +check_block_avx2: # @check_block_avx2 +# %bb.0: + push rbp + mov rbp, rsp + and rsp, -8 + # kill: def $esi killed $esi def $rsi + mov rcx, rdx + shr rcx, 32 + lea eax, [rsi + 7] + test esi, esi + cmovns eax, esi + sar eax, 3 + cdqe + imul rax, rcx + shr rax, 29 + and eax, -8 + imul ecx, edx, 1203114875 + shr ecx, 27 + mov esi, dword ptr [rdi + 4*rax] + bt esi, ecx + jae .LBB0_8 +# %bb.1: + imul ecx, edx, 1150766481 + shr ecx, 27 + mov esi, dword ptr [rdi + 4*rax + 4] + bt esi, ecx + jae .LBB0_8 +# %bb.2: + imul ecx, edx, -2010862245 + shr ecx, 27 + mov esi, dword ptr [rdi + 4*rax + 8] + bt esi, ecx + jae .LBB0_8 +# %bb.3: + imul ecx, edx, -1565054819 + shr ecx, 27 + mov esi, dword ptr [rdi + 4*rax + 12] + bt esi, ecx + jae .LBB0_8 +# %bb.4: + imul ecx, edx, 1884591559 + shr ecx, 27 + mov esi, dword ptr [rdi + 4*rax + 16] + bt esi, ecx + jae .LBB0_8 +# %bb.5: + imul ecx, edx, 770785867 + shr ecx, 27 + mov esi, dword ptr [rdi + 4*rax + 20] + bt esi, ecx + jae .LBB0_8 +# %bb.6: + imul ecx, edx, -1627633337 + shr ecx, 27 + mov esi, dword ptr [rdi + 4*rax + 24] + bt esi, ecx + jae .LBB0_8 +# %bb.7: + imul ecx, edx, 1550580529 + shr ecx, 27 + mov eax, dword ptr [rdi + 4*rax + 28] + bt eax, ecx + setb al + # kill: def $al killed $al killed $eax + mov rsp, rbp + pop rbp + ret +.LBB0_8: + xor eax, eax + # kill: def $al killed $al killed $eax + mov rsp, rbp + pop rbp + ret +.Lfunc_end0: + .size check_block_avx2, .Lfunc_end0-check_block_avx2 + # -- End function + .globl check_bulk_avx2 # -- 
Begin function check_bulk_avx2 + .p2align 4, 0x90 + .type check_bulk_avx2,@function +check_bulk_avx2: # @check_bulk_avx2 +# %bb.0: + # kill: def $esi killed $esi def $rsi + test r8d, r8d + jle .LBB1_19 +# %bb.1: + push rbp + mov rbp, rsp + push rbx + and rsp, -8 + lea eax, [rsi + 7] + test esi, esi + cmovns eax, esi + sar eax, 3 + cdqe + mov esi, r8d + xor r8d, r8d + .p2align 4, 0x90 +.LBB1_4: # =>This Inner Loop Header: Depth=1 + mov r10, qword ptr [rdx + 8*r8] + mov r9, r10 + shr r9, 32 + imul r9, rax + shr r9, 29 + and r9d, -8 + imul r11d, r10d, 1203114875 + shr r11d, 27 + mov ebx, dword ptr [rdi + 4*r9] + bt ebx, r11d + jae .LBB1_2 +# %bb.5: # in Loop: Header=BB1_4 Depth=1 + imul r11d, r10d, 1150766481 + shr r11d, 27 + mov ebx, dword ptr [rdi + 4*r9 + 4] + bt ebx, r11d + jae .LBB1_2 +# %bb.6: # in Loop: Header=BB1_4 Depth=1 + imul r11d, r10d, -2010862245 + shr r11d, 27 + mov ebx, dword ptr [rdi + 4*r9 + 8] + bt ebx, r11d + jae .LBB1_2 +# %bb.7: # in Loop: Header=BB1_4 Depth=1 + imul r11d, r10d, -1565054819 + shr r11d, 27 + mov ebx, dword ptr [rdi + 4*r9 + 12] + bt ebx, r11d + jae .LBB1_2 +# %bb.8: # in Loop: Header=BB1_4 Depth=1 + imul r11d, r10d, 1884591559 + shr r11d, 27 + mov ebx, dword ptr [rdi + 4*r9 + 16] + bt ebx, r11d + jae .LBB1_2 +# %bb.9: # in Loop: Header=BB1_4 Depth=1 + imul r11d, r10d, 770785867 + shr r11d, 27 + mov ebx, dword ptr [rdi + 4*r9 + 20] + bt ebx, r11d + jae .LBB1_2 +# %bb.10: # in Loop: Header=BB1_4 Depth=1 + imul r11d, r10d, -1627633337 + shr r11d, 27 + mov ebx, dword ptr [rdi + 4*r9 + 24] + bt ebx, r11d + jae .LBB1_2 +# %bb.11: # in Loop: Header=BB1_4 Depth=1 + imul r10d, r10d, 1550580529 + shr r10d, 27 + mov r9d, dword ptr [rdi + 4*r9 + 28] + bt r9d, r10d + setb r9b + mov byte ptr [rcx + r8], r9b + inc r8 + cmp rsi, r8 + jne .LBB1_4 + jmp .LBB1_18 + .p2align 4, 0x90 +.LBB1_2: # in Loop: Header=BB1_4 Depth=1 + xor r9d, r9d + mov byte ptr [rcx + r8], r9b + inc r8 + cmp rsi, r8 + jne .LBB1_4 +.LBB1_18: + # lea rsp, [rbp - 8] + pop rbx + pop rbp +.LBB1_19: + ret +.Lfunc_end1: + .size check_bulk_avx2, .Lfunc_end1-check_bulk_avx2 + # -- End function + .section .rodata.cst32,"aM",@progbits,32 + .p2align 5, 0x0 # -- Begin function insert_block_avx2 +.LCPI2_0: + .long 1203114875 # 0x47b6137b + .long 1150766481 # 0x44974d91 + .long 2284105051 # 0x8824ad5b + .long 2729912477 # 0xa2b7289d + .long 1884591559 # 0x705495c7 + .long 770785867 # 0x2df1424b + .long 2667333959 # 0x9efc4947 + .long 1550580529 # 0x5c6bfb31 + .section .rodata.cst4,"aM",@progbits,4 + .p2align 2, 0x0 +.LCPI2_1: + .long 1 # 0x1 + .text + .globl insert_block_avx2 + .p2align 4, 0x90 + .type insert_block_avx2,@function +insert_block_avx2: # @insert_block_avx2 +# %bb.0: + push rbp + mov rbp, rsp + and rsp, -8 + # kill: def $esi killed $esi def $rsi + vmovd xmm0, edx + shr rdx, 32 + lea eax, [rsi + 7] + test esi, esi + cmovns eax, esi + sar eax, 3 + cdqe + imul rax, rdx + shr rax, 27 + movabs rcx, 17179869152 + vpbroadcastd ymm0, xmm0 + vpmulld ymm0, ymm0, ymmword ptr [rip + .LCPI2_0] + and rcx, rax + vpsrld ymm0, ymm0, 27 + vpbroadcastd ymm1, dword ptr [rip + .LCPI2_1] # ymm1 = [1,1,1,1,1,1,1,1] + vpsllvd ymm0, ymm1, ymm0 + vpor ymm0, ymm0, ymmword ptr [rdi + rcx] + vmovdqu ymmword ptr [rdi + rcx], ymm0 + mov rsp, rbp + pop rbp + vzeroupper + ret +.Lfunc_end2: + .size insert_block_avx2, .Lfunc_end2-insert_block_avx2 + # -- End function + .section .rodata.cst32,"aM",@progbits,32 + .p2align 5, 0x0 # -- Begin function insert_bulk_avx2 +.LCPI3_0: + .long 1203114875 # 0x47b6137b + .long 1150766481 # 
0x44974d91 + .long 2284105051 # 0x8824ad5b + .long 2729912477 # 0xa2b7289d + .long 1884591559 # 0x705495c7 + .long 770785867 # 0x2df1424b + .long 2667333959 # 0x9efc4947 + .long 1550580529 # 0x5c6bfb31 + .section .rodata.cst4,"aM",@progbits,4 + .p2align 2, 0x0 +.LCPI3_1: + .long 1 # 0x1 + .text + .globl insert_bulk_avx2 + .p2align 4, 0x90 + .type insert_bulk_avx2,@function +insert_bulk_avx2: # @insert_bulk_avx2 +# %bb.0: + # kill: def $esi killed $esi def $rsi + test ecx, ecx + jle .LBB3_4 +# %bb.1: + push rbp + mov rbp, rsp + and rsp, -8 + lea eax, [rsi + 7] + test esi, esi + cmovns eax, esi + sar eax, 3 + cdqe + mov ecx, ecx + xor esi, esi + movabs r8, 17179869152 + vmovdqa ymm0, ymmword ptr [rip + .LCPI3_0] # ymm0 = [1203114875,1150766481,2284105051,2729912477,1884591559,770785867,2667333959,1550580529] + vpbroadcastd ymm1, dword ptr [rip + .LCPI3_1] # ymm1 = [1,1,1,1,1,1,1,1] + .p2align 4, 0x90 +.LBB3_2: # =>This Inner Loop Header: Depth=1 + mov r9, qword ptr [rdx + 8*rsi] + vmovd xmm2, r9d + shr r9, 32 + imul r9, rax + shr r9, 27 + and r9, r8 + vpbroadcastd ymm2, xmm2 + vpmulld ymm2, ymm2, ymm0 + vpsrld ymm2, ymm2, 27 + vpsllvd ymm2, ymm1, ymm2 + vpor ymm2, ymm2, ymmword ptr [rdi + r9] + vmovdqu ymmword ptr [rdi + r9], ymm2 + inc rsi + cmp rcx, rsi + jne .LBB3_2 +# %bb.3: + mov rsp, rbp + pop rbp +.LBB3_4: + vzeroupper + ret +.Lfunc_end3: + .size insert_bulk_avx2, .Lfunc_end3-insert_bulk_avx2 + # -- End function + .ident "clang version 19.1.6 (https://github.com/conda-forge/clangdev-feedstock a097c63bb6a9919682224023383a143d482c552e)" + .section ".note.GNU-stack","",@progbits + .addrsig \ No newline at end of file diff --git a/parquet/metadata/_lib/bloom_filter_block_sse4_amd64.s b/parquet/metadata/_lib/bloom_filter_block_sse4_amd64.s new file mode 100644 index 00000000..593b61b8 --- /dev/null +++ b/parquet/metadata/_lib/bloom_filter_block_sse4_amd64.s @@ -0,0 +1,322 @@ + .text + .intel_syntax noprefix + .file "bloom_filter_block.c" + .globl check_block_sse4 # -- Begin function check_block_sse4 + .p2align 4, 0x90 + .type check_block_sse4,@function +check_block_sse4: # @check_block_sse4 +# %bb.0: + push rbp + mov rbp, rsp + and rsp, -8 + # kill: def $esi killed $esi def $rsi + mov rcx, rdx + shr rcx, 32 + lea eax, [rsi + 7] + test esi, esi + cmovns eax, esi + sar eax, 3 + cdqe + imul rax, rcx + shr rax, 29 + and eax, -8 + imul ecx, edx, 1203114875 + shr ecx, 27 + mov esi, dword ptr [rdi + 4*rax] + bt esi, ecx + jae .LBB0_8 +# %bb.1: + imul ecx, edx, 1150766481 + shr ecx, 27 + mov esi, dword ptr [rdi + 4*rax + 4] + bt esi, ecx + jae .LBB0_8 +# %bb.2: + imul ecx, edx, -2010862245 + shr ecx, 27 + mov esi, dword ptr [rdi + 4*rax + 8] + bt esi, ecx + jae .LBB0_8 +# %bb.3: + imul ecx, edx, -1565054819 + shr ecx, 27 + mov esi, dword ptr [rdi + 4*rax + 12] + bt esi, ecx + jae .LBB0_8 +# %bb.4: + imul ecx, edx, 1884591559 + shr ecx, 27 + mov esi, dword ptr [rdi + 4*rax + 16] + bt esi, ecx + jae .LBB0_8 +# %bb.5: + imul ecx, edx, 770785867 + shr ecx, 27 + mov esi, dword ptr [rdi + 4*rax + 20] + bt esi, ecx + jae .LBB0_8 +# %bb.6: + imul ecx, edx, -1627633337 + shr ecx, 27 + mov esi, dword ptr [rdi + 4*rax + 24] + bt esi, ecx + jae .LBB0_8 +# %bb.7: + imul ecx, edx, 1550580529 + shr ecx, 27 + mov eax, dword ptr [rdi + 4*rax + 28] + bt eax, ecx + setb al + # kill: def $al killed $al killed $eax + mov rsp, rbp + pop rbp + ret +.LBB0_8: + xor eax, eax + # kill: def $al killed $al killed $eax + mov rsp, rbp + pop rbp + ret +.Lfunc_end0: + .size check_block_sse4, .Lfunc_end0-check_block_sse4 + # -- 
End function + .globl check_bulk_sse4 # -- Begin function check_bulk_sse4 + .p2align 4, 0x90 + .type check_bulk_sse4,@function +check_bulk_sse4: # @check_bulk_sse4 +# %bb.0: + # kill: def $esi killed $esi def $rsi + test r8d, r8d + jle .LBB1_19 +# %bb.1: + push rbp + mov rbp, rsp + push rbx + and rsp, -8 + lea eax, [rsi + 7] + test esi, esi + cmovns eax, esi + sar eax, 3 + cdqe + mov esi, r8d + xor r8d, r8d + .p2align 4, 0x90 +.LBB1_4: # =>This Inner Loop Header: Depth=1 + mov r10, qword ptr [rdx + 8*r8] + mov r9, r10 + shr r9, 32 + imul r9, rax + shr r9, 29 + and r9d, -8 + imul r11d, r10d, 1203114875 + shr r11d, 27 + mov ebx, dword ptr [rdi + 4*r9] + bt ebx, r11d + jae .LBB1_2 +# %bb.5: # in Loop: Header=BB1_4 Depth=1 + imul r11d, r10d, 1150766481 + shr r11d, 27 + mov ebx, dword ptr [rdi + 4*r9 + 4] + bt ebx, r11d + jae .LBB1_2 +# %bb.6: # in Loop: Header=BB1_4 Depth=1 + imul r11d, r10d, -2010862245 + shr r11d, 27 + mov ebx, dword ptr [rdi + 4*r9 + 8] + bt ebx, r11d + jae .LBB1_2 +# %bb.7: # in Loop: Header=BB1_4 Depth=1 + imul r11d, r10d, -1565054819 + shr r11d, 27 + mov ebx, dword ptr [rdi + 4*r9 + 12] + bt ebx, r11d + jae .LBB1_2 +# %bb.8: # in Loop: Header=BB1_4 Depth=1 + imul r11d, r10d, 1884591559 + shr r11d, 27 + mov ebx, dword ptr [rdi + 4*r9 + 16] + bt ebx, r11d + jae .LBB1_2 +# %bb.9: # in Loop: Header=BB1_4 Depth=1 + imul r11d, r10d, 770785867 + shr r11d, 27 + mov ebx, dword ptr [rdi + 4*r9 + 20] + bt ebx, r11d + jae .LBB1_2 +# %bb.10: # in Loop: Header=BB1_4 Depth=1 + imul r11d, r10d, -1627633337 + shr r11d, 27 + mov ebx, dword ptr [rdi + 4*r9 + 24] + bt ebx, r11d + jae .LBB1_2 +# %bb.11: # in Loop: Header=BB1_4 Depth=1 + imul r10d, r10d, 1550580529 + shr r10d, 27 + mov r9d, dword ptr [rdi + 4*r9 + 28] + bt r9d, r10d + setb r9b + mov byte ptr [rcx + r8], r9b + inc r8 + cmp rsi, r8 + jne .LBB1_4 + jmp .LBB1_18 + .p2align 4, 0x90 +.LBB1_2: # in Loop: Header=BB1_4 Depth=1 + xor r9d, r9d + mov byte ptr [rcx + r8], r9b + inc r8 + cmp rsi, r8 + jne .LBB1_4 +.LBB1_18: + # lea rsp, [rbp - 8] + pop rbx + pop rbp +.LBB1_19: + ret +.Lfunc_end1: + .size check_bulk_sse4, .Lfunc_end1-check_bulk_sse4 + # -- End function + .section .rodata.cst16,"aM",@progbits,16 + .p2align 4, 0x0 # -- Begin function insert_block_sse4 +.LCPI2_0: + .long 1203114875 # 0x47b6137b + .long 1150766481 # 0x44974d91 + .long 2284105051 # 0x8824ad5b + .long 2729912477 # 0xa2b7289d +.LCPI2_1: + .long 1065353216 # 0x3f800000 + .long 1065353216 # 0x3f800000 + .long 1065353216 # 0x3f800000 + .long 1065353216 # 0x3f800000 +.LCPI2_2: + .long 1884591559 # 0x705495c7 + .long 770785867 # 0x2df1424b + .long 2667333959 # 0x9efc4947 + .long 1550580529 # 0x5c6bfb31 + .text + .globl insert_block_sse4 + .p2align 4, 0x90 + .type insert_block_sse4,@function +insert_block_sse4: # @insert_block_sse4 +# %bb.0: + push rbp + mov rbp, rsp + and rsp, -8 + # kill: def $esi killed $esi def $rsi + movd xmm0, edx + shr rdx, 32 + lea eax, [rsi + 7] + test esi, esi + cmovns eax, esi + sar eax, 3 + movsxd rcx, eax + imul rcx, rdx + shr rcx, 27 + movabs rax, 17179869152 + and rax, rcx + pshufd xmm0, xmm0, 0 # xmm0 = xmm0[0,0,0,0] + movdqa xmm1, xmmword ptr [rip + .LCPI2_0] # xmm1 = [1203114875,1150766481,2284105051,2729912477] + pmulld xmm1, xmm0 + psrld xmm1, 27 + pslld xmm1, 23 + movdqa xmm2, xmmword ptr [rip + .LCPI2_1] # xmm2 = [1065353216,1065353216,1065353216,1065353216] + paddd xmm1, xmm2 + cvttps2dq xmm1, xmm1 + movups xmm3, xmmword ptr [rdi + rax] + orps xmm3, xmm1 + movups xmm1, xmmword ptr [rdi + rax + 16] + movups xmmword ptr [rdi + rax], 
xmm3 + pmulld xmm0, xmmword ptr [rip + .LCPI2_2] + psrld xmm0, 27 + pslld xmm0, 23 + paddd xmm0, xmm2 + cvttps2dq xmm0, xmm0 + orps xmm0, xmm1 + movups xmmword ptr [rdi + rax + 16], xmm0 + mov rsp, rbp + pop rbp + ret +.Lfunc_end2: + .size insert_block_sse4, .Lfunc_end2-insert_block_sse4 + # -- End function + .section .rodata.cst16,"aM",@progbits,16 + .p2align 4, 0x0 # -- Begin function insert_bulk_sse4 +.LCPI3_0: + .long 1203114875 # 0x47b6137b + .long 1150766481 # 0x44974d91 + .long 2284105051 # 0x8824ad5b + .long 2729912477 # 0xa2b7289d +.LCPI3_1: + .long 1065353216 # 0x3f800000 + .long 1065353216 # 0x3f800000 + .long 1065353216 # 0x3f800000 + .long 1065353216 # 0x3f800000 +.LCPI3_2: + .long 1884591559 # 0x705495c7 + .long 770785867 # 0x2df1424b + .long 2667333959 # 0x9efc4947 + .long 1550580529 # 0x5c6bfb31 + .text + .globl insert_bulk_sse4 + .p2align 4, 0x90 + .type insert_bulk_sse4,@function +insert_bulk_sse4: # @insert_bulk_sse4 +# %bb.0: + # kill: def $esi killed $esi def $rsi + test ecx, ecx + jle .LBB3_4 +# %bb.1: + push rbp + mov rbp, rsp + and rsp, -8 + lea eax, [rsi + 7] + test esi, esi + cmovns eax, esi + sar eax, 3 + cdqe + mov ecx, ecx + xor esi, esi + movabs r8, 17179869152 + movdqa xmm0, xmmword ptr [rip + .LCPI3_0] # xmm0 = [1203114875,1150766481,2284105051,2729912477] + movdqa xmm1, xmmword ptr [rip + .LCPI3_1] # xmm1 = [1065353216,1065353216,1065353216,1065353216] + movdqa xmm2, xmmword ptr [rip + .LCPI3_2] # xmm2 = [1884591559,770785867,2667333959,1550580529] + .p2align 4, 0x90 +.LBB3_2: # =>This Inner Loop Header: Depth=1 + mov r9, qword ptr [rdx + 8*rsi] + movd xmm3, r9d + shr r9, 32 + imul r9, rax + shr r9, 27 + and r9, r8 + pshufd xmm3, xmm3, 0 # xmm3 = xmm3[0,0,0,0] + movdqa xmm4, xmm3 + pmulld xmm4, xmm0 + psrld xmm4, 27 + pslld xmm4, 23 + paddd xmm4, xmm1 + cvttps2dq xmm4, xmm4 + movups xmm5, xmmword ptr [rdi + r9] + orps xmm5, xmm4 + movups xmm4, xmmword ptr [rdi + r9 + 16] + movups xmmword ptr [rdi + r9], xmm5 + pmulld xmm3, xmm2 + psrld xmm3, 27 + pslld xmm3, 23 + paddd xmm3, xmm1 + cvttps2dq xmm3, xmm3 + orps xmm3, xmm4 + movups xmmword ptr [rdi + r9 + 16], xmm3 + inc rsi + cmp rcx, rsi + jne .LBB3_2 +# %bb.3: + mov rsp, rbp + pop rbp +.LBB3_4: + ret +.Lfunc_end3: + .size insert_bulk_sse4, .Lfunc_end3-insert_bulk_sse4 + # -- End function + .ident "clang version 19.1.6 (https://github.com/conda-forge/clangdev-feedstock a097c63bb6a9919682224023383a143d482c552e)" + .section ".note.GNU-stack","",@progbits + .addrsig \ No newline at end of file diff --git a/parquet/metadata/adaptive_bloom_filter.go b/parquet/metadata/adaptive_bloom_filter.go new file mode 100644 index 00000000..c040645a --- /dev/null +++ b/parquet/metadata/adaptive_bloom_filter.go @@ -0,0 +1,224 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package metadata + +import ( + "io" + "runtime" + "slices" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/bitutil" + "github.com/apache/arrow-go/v18/arrow/memory" + "github.com/apache/arrow-go/v18/parquet/internal/encryption" + format "github.com/apache/arrow-go/v18/parquet/internal/gen-go/parquet" + "github.com/apache/arrow-go/v18/parquet/schema" +) + +type bloomFilterCandidate struct { + bloomFilter blockSplitBloomFilter + expectedNDV uint32 +} + +func newBloomFilterCandidate(expectedNDV, numBytes, minBytes, maxBytes uint32, h Hasher, mem memory.Allocator) *bloomFilterCandidate { + if numBytes < minBytes { + numBytes = minBytes + } + + if numBytes > maxBytes { + numBytes = maxBytes + } + + // get next power of 2 if it's not a power of 2 + if (numBytes & (numBytes - 1)) != 0 { + numBytes = uint32(bitutil.NextPowerOf2(int(numBytes))) + } + + buf := memory.NewResizableBuffer(mem) + buf.ResizeNoShrink(int(numBytes)) + bf := blockSplitBloomFilter{ + data: buf, + bitset32: arrow.Uint32Traits.CastFromBytes(buf.Bytes()), + hasher: h, + algorithm: defaultAlgorithm, + hashStrategy: defaultHashStrategy, + compression: defaultCompression, + } + runtime.SetFinalizer(&bf, func(f *blockSplitBloomFilter) { + f.data.Release() + }) + return &bloomFilterCandidate{bloomFilter: bf, expectedNDV: expectedNDV} +} + +type adaptiveBlockSplitBloomFilter struct { + mem memory.Allocator + candidates []*bloomFilterCandidate + largestCandidate *bloomFilterCandidate + numDistinct int64 + finalized bool + + maxBytes, minBytes uint32 + minCandidateNDV int + hasher Hasher + hashStrategy format.BloomFilterHash + algorithm format.BloomFilterAlgorithm + compression format.BloomFilterCompression + + column *schema.Column +} + +func NewAdaptiveBlockSplitBloomFilter(maxBytes uint32, numCandidates int, fpp float64, column *schema.Column, mem memory.Allocator) BloomFilterBuilder { + ret := &adaptiveBlockSplitBloomFilter{ + mem: mem, + maxBytes: min(maximumBloomFilterBytes, maxBytes), + minBytes: minimumBloomFilterBytes, + minCandidateNDV: 16, + hasher: xxhasher{}, + column: column, + hashStrategy: defaultHashStrategy, + algorithm: defaultAlgorithm, + compression: defaultCompression, + } + + ret.initCandidates(maxBytes, numCandidates, fpp) + return ret +} + +func (b *adaptiveBlockSplitBloomFilter) getAlg() *format.BloomFilterAlgorithm { + return &b.algorithm +} + +func (b *adaptiveBlockSplitBloomFilter) getHashStrategy() *format.BloomFilterHash { + return &b.hashStrategy +} + +func (b *adaptiveBlockSplitBloomFilter) getCompression() *format.BloomFilterCompression { + return &b.compression +} + +func (b *adaptiveBlockSplitBloomFilter) optimalCandidate() *bloomFilterCandidate { + return slices.MinFunc(b.candidates, func(a, b *bloomFilterCandidate) int { + return int(b.bloomFilter.Size() - a.bloomFilter.Size()) + }) +} + +func (b *adaptiveBlockSplitBloomFilter) Hasher() Hasher { return b.hasher } + +func (b *adaptiveBlockSplitBloomFilter) InsertHash(hash uint64) { + if b.finalized { + panic("adaptive bloom filter has been marked finalized, no more data allowed") + } + + if !b.largestCandidate.bloomFilter.CheckHash(hash) { + b.numDistinct++ + } + + b.candidates = slices.DeleteFunc(b.candidates, func(c *bloomFilterCandidate) bool { + return c.expectedNDV < uint32(b.numDistinct) && c != b.largestCandidate + }) + + for _, c := range b.candidates { + c.bloomFilter.InsertHash(hash) + } +} + +func (b *adaptiveBlockSplitBloomFilter) InsertBulk(hashes []uint64) { + if b.finalized { + panic("adaptive 
bloom filter has been marked finalized, no more data allowed") + } + + for _, h := range hashes { + if !b.largestCandidate.bloomFilter.CheckHash(h) { + b.numDistinct++ + } + } + + b.candidates = slices.DeleteFunc(b.candidates, func(c *bloomFilterCandidate) bool { + return c.expectedNDV < uint32(b.numDistinct) && c != b.largestCandidate + }) + + for _, c := range b.candidates { + c.bloomFilter.InsertBulk(hashes) + } +} + +func (b *adaptiveBlockSplitBloomFilter) Size() int64 { + return b.optimalCandidate().bloomFilter.Size() +} + +func (b *adaptiveBlockSplitBloomFilter) CheckHash(hash uint64) bool { + return b.largestCandidate.bloomFilter.CheckHash(hash) +} + +func (b *adaptiveBlockSplitBloomFilter) WriteTo(w io.Writer, enc encryption.Encryptor) (int, error) { + b.finalized = true + + return b.optimalCandidate().bloomFilter.WriteTo(w, enc) +} + +func (b *adaptiveBlockSplitBloomFilter) initCandidates(maxBytes uint32, numCandidates int, fpp float64) { + b.candidates = make([]*bloomFilterCandidate, 0, numCandidates) + candidateByteSize := b.calcBoundedPowerOf2(maxBytes) + for range numCandidates { + candidateExpectedNDV := b.expectedNDV(candidateByteSize, fpp) + if candidateExpectedNDV <= 0 { + break + } + + b.candidates = append(b.candidates, newBloomFilterCandidate(uint32(candidateExpectedNDV), + candidateByteSize, b.minBytes, b.maxBytes, b.hasher, b.mem)) + candidateByteSize = b.calcBoundedPowerOf2(candidateByteSize / 2) + } + + if len(b.candidates) == 0 { + // maxBytes is too small, but at least one candidate will be generated + b.candidates = append(b.candidates, newBloomFilterCandidate(uint32(b.minCandidateNDV), + b.minBytes, b.minBytes, b.maxBytes, b.hasher, b.mem)) + } + + b.largestCandidate = slices.MaxFunc(b.candidates, func(a, b *bloomFilterCandidate) int { + return int(b.bloomFilter.Size() - a.bloomFilter.Size()) + }) +} + +func (b *adaptiveBlockSplitBloomFilter) expectedNDV(numBytes uint32, fpp float64) int { + var ( + expectedNDV, optimalBytes uint32 + ) + + const ndvStep = 500 + for optimalBytes < numBytes { + expectedNDV += ndvStep + optimalBytes = optimalNumBytes(expectedNDV, fpp) + } + + // make sure it is slightly smaller than what numBytes supports + expectedNDV -= ndvStep + return int(max(0, expectedNDV)) +} + +func (b *adaptiveBlockSplitBloomFilter) calcBoundedPowerOf2(numBytes uint32) uint32 { + if numBytes < b.minBytes { + numBytes = b.minBytes + } + + if numBytes&(numBytes-1) != 0 { + numBytes = uint32(bitutil.NextPowerOf2(int(numBytes))) + } + + return max(min(numBytes, b.maxBytes), b.minBytes) +} diff --git a/parquet/metadata/bloom_filter.go b/parquet/metadata/bloom_filter.go new file mode 100644 index 00000000..f0ec8faf --- /dev/null +++ b/parquet/metadata/bloom_filter.go @@ -0,0 +1,565 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package metadata + +import ( + "errors" + "fmt" + "io" + "math" + "sync" + "unsafe" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/bitutil" + "github.com/apache/arrow-go/v18/arrow/memory" + "github.com/apache/arrow-go/v18/internal/bitutils" + "github.com/apache/arrow-go/v18/parquet" + "github.com/apache/arrow-go/v18/parquet/internal/debug" + "github.com/apache/arrow-go/v18/parquet/internal/encryption" + format "github.com/apache/arrow-go/v18/parquet/internal/gen-go/parquet" + "github.com/apache/arrow-go/v18/parquet/internal/thrift" + "github.com/apache/arrow-go/v18/parquet/internal/utils" + "github.com/apache/arrow-go/v18/parquet/schema" + "github.com/cespare/xxhash/v2" +) + +const ( + bytesPerFilterBlock = 32 + bitsSetPerBlock = 8 + minimumBloomFilterBytes = bytesPerFilterBlock + // currently using 128MB as maximum size, should probably be reconsidered + maximumBloomFilterBytes = 128 * 1024 * 1024 +) + +var ( + salt = [bitsSetPerBlock]uint32{ + 0x47b6137b, 0x44974d91, 0x8824ad5b, 0xa2b7289d, + 0x705495c7, 0x2df1424b, 0x9efc4947, 0x5c6bfb31} + + defaultHashStrategy = format.BloomFilterHash{XXHASH: &format.XxHash{}} + defaultAlgorithm = format.BloomFilterAlgorithm{BLOCK: &format.SplitBlockAlgorithm{}} + defaultCompression = format.BloomFilterCompression{UNCOMPRESSED: &format.Uncompressed{}} +) + +func optimalNumBytes(ndv uint32, fpp float64) uint32 { + optimalBits := optimalNumBits(ndv, fpp) + debug.Assert(bitutil.IsMultipleOf8(int64(optimalBits)), "optimal bits should be multiple of 8") + return optimalBits >> 3 +} + +func optimalNumBits(ndv uint32, fpp float64) uint32 { + debug.Assert(fpp > 0 && fpp < 1, "false positive prob must be in (0, 1)") + var ( + m = -8 * float64(ndv) / math.Log(1-math.Pow(fpp, 1.0/8.0)) + numBits uint32 + ) + + if m < 0 || m > maximumBloomFilterBytes>>3 { + numBits = maximumBloomFilterBytes << 3 + } else { + numBits = uint32(m) + } + + // round up to lower bound + if numBits < minimumBloomFilterBytes<<3 { + numBits = minimumBloomFilterBytes << 3 + } + + // get next power of 2 if bits is not power of 2 + if (numBits & (numBits - 1)) != 0 { + numBits = uint32(bitutil.NextPowerOf2(int(numBits))) + } + return numBits +} + +type Hasher interface { + Sum64(b []byte) uint64 + Sum64s(b [][]byte) []uint64 +} + +type xxhasher struct{} + +func (xxhasher) Sum64(b []byte) uint64 { + return xxhash.Sum64(b) +} + +func (xxhasher) Sum64s(b [][]byte) (vals []uint64) { + vals = make([]uint64, len(b)) + for i, v := range b { + vals[i] = xxhash.Sum64(v) + } + return +} + +func GetHash[T parquet.ColumnTypes](h Hasher, v T) uint64 { + return h.Sum64(getBytes(v)) +} + +func GetHashes[T parquet.ColumnTypes](h Hasher, vals []T) []uint64 { + return h.Sum64s(getBytesSlice(vals)) +} + +func GetSpacedHashes[T parquet.ColumnTypes](h Hasher, numValid int64, vals []T, validBits []byte, validBitsOffset int64) []uint64 { + if numValid == 0 { + return []uint64{} + } + + out := make([]uint64, 0, numValid) + + // TODO: replace with bitset run reader pool + setReader := bitutils.NewSetBitRunReader(validBits, validBitsOffset, int64(len(vals))) + for { + run := setReader.NextRun() + if run.Length == 0 { + break + } + + out = append(out, h.Sum64s(getBytesSlice(vals[run.Pos:run.Pos+run.Length]))...) 
+ } + return out +} + +func getBytes[T parquet.ColumnTypes](v T) []byte { + switch v := any(v).(type) { + case parquet.ByteArray: + return v + case parquet.FixedLenByteArray: + return v + case parquet.Int96: + return v[:] + } + + return unsafe.Slice((*byte)(unsafe.Pointer(&v)), unsafe.Sizeof(v)) +} + +func getBytesSlice[T parquet.ColumnTypes](v []T) [][]byte { + b := make([][]byte, len(v)) + switch v := any(v).(type) { + case []parquet.ByteArray: + for i, vv := range v { + b[i] = vv + } + return b + case []parquet.FixedLenByteArray: + for i, vv := range v { + b[i] = vv + } + return b + case []parquet.Int96: + for i, vv := range v { + b[i] = vv[:] + } + return b + } + + var z T + sz, ptr := int(unsafe.Sizeof(z)), unsafe.SliceData(v) + raw := unsafe.Slice((*byte)(unsafe.Pointer(ptr)), sz*len(v)) + for i := range b { + b[i] = raw[i*sz : (i+1)*sz] + } + + return b +} + +type blockSplitBloomFilter struct { + data *memory.Buffer + bitset32 []uint32 + + hasher Hasher + algorithm format.BloomFilterAlgorithm + hashStrategy format.BloomFilterHash + compression format.BloomFilterCompression +} + +func (b *blockSplitBloomFilter) getAlg() *format.BloomFilterAlgorithm { + return &b.algorithm +} + +func (b *blockSplitBloomFilter) getHashStrategy() *format.BloomFilterHash { + return &b.hashStrategy +} + +func (b *blockSplitBloomFilter) getCompression() *format.BloomFilterCompression { + return &b.compression +} + +func (b *blockSplitBloomFilter) CheckHash(hash uint64) bool { + return checkHash(b.bitset32, hash) +} + +func (b *blockSplitBloomFilter) CheckBulk(hashes []uint64) []bool { + results := make([]bool, len(hashes)) + checkBulk(b.bitset32, hashes, results) + return results +} + +func (b *blockSplitBloomFilter) InsertHash(hash uint64) { + insertHash(b.bitset32, hash) +} + +func (b *blockSplitBloomFilter) InsertBulk(hashes []uint64) { + insertBulk(b.bitset32, hashes) +} + +func (b *blockSplitBloomFilter) Hasher() Hasher { + return b.hasher +} + +func (b *blockSplitBloomFilter) Size() int64 { + return int64(len(b.bitset32) * 4) +} + +func (b *blockSplitBloomFilter) WriteTo(w io.Writer, enc encryption.Encryptor) (int, error) { + if enc != nil { + n := enc.Encrypt(w, b.data.Bytes()) + return n, nil + } + return w.Write(b.data.Bytes()) +} + +func NewBloomFilter(numBytes, maxBytes uint32, mem memory.Allocator) BloomFilterBuilder { + if numBytes < minimumBloomFilterBytes { + numBytes = minimumBloomFilterBytes + } + + if maxBytes > maximumBloomFilterBytes { + maxBytes = maximumBloomFilterBytes + } + + if numBytes > maxBytes { + numBytes = maxBytes + } + + // get next power of 2 if it's not a power of 2 + if (numBytes & (numBytes - 1)) != 0 { + numBytes = uint32(bitutil.NextPowerOf2(int(numBytes))) + } + + buf := memory.NewResizableBuffer(mem) + buf.ResizeNoShrink(int(numBytes)) + bf := &blockSplitBloomFilter{ + data: buf, + bitset32: arrow.Uint32Traits.CastFromBytes(buf.Bytes()), + hasher: xxhasher{}, + algorithm: format.BloomFilterAlgorithm{BLOCK: &format.SplitBlockAlgorithm{}}, + hashStrategy: format.BloomFilterHash{XXHASH: &format.XxHash{}}, + compression: format.BloomFilterCompression{UNCOMPRESSED: &format.Uncompressed{}}, + } + addCleanup(bf, nil) + return bf +} + +func NewBloomFilterFromNDVAndFPP(ndv uint32, fpp float64, maxBytes int64, mem memory.Allocator) BloomFilterBuilder { + numBytes := optimalNumBytes(ndv, fpp) + if numBytes > uint32(maxBytes) { + numBytes = uint32(maxBytes) + } + + buf := memory.NewResizableBuffer(mem) + buf.ResizeNoShrink(int(numBytes)) + bf := &blockSplitBloomFilter{ + data: 
buf, + bitset32: arrow.Uint32Traits.CastFromBytes(buf.Bytes()), + hasher: xxhasher{}, + algorithm: format.BloomFilterAlgorithm{BLOCK: &format.SplitBlockAlgorithm{}}, + hashStrategy: format.BloomFilterHash{XXHASH: &format.XxHash{}}, + compression: format.BloomFilterCompression{UNCOMPRESSED: &format.Uncompressed{}}, + } + addCleanup(bf, nil) + return bf +} + +type BloomFilterBuilder interface { + Hasher() Hasher + Size() int64 + InsertHash(hash uint64) + InsertBulk(hashes []uint64) + WriteTo(io.Writer, encryption.Encryptor) (int, error) + + getAlg() *format.BloomFilterAlgorithm + getHashStrategy() *format.BloomFilterHash + getCompression() *format.BloomFilterCompression +} + +type BloomFilter interface { + Hasher() Hasher + CheckHash(hash uint64) bool + Size() int64 +} + +type TypedBloomFilter[T parquet.ColumnTypes] struct { + BloomFilter +} + +func (b *TypedBloomFilter[T]) Check(v T) bool { + h := b.Hasher() + return b.CheckHash(h.Sum64(getBytes(v))) +} + +func validateBloomFilterHeader(hdr *format.BloomFilterHeader) error { + if hdr == nil { + return errors.New("bloom filter header must not be nil") + } + + if !hdr.Algorithm.IsSetBLOCK() { + return fmt.Errorf("unsupported bloom filter algorithm: %s", hdr.Algorithm) + } + + if !hdr.Compression.IsSetUNCOMPRESSED() { + return fmt.Errorf("unsupported bloom filter compression: %s", hdr.Compression) + } + + if !hdr.Hash.IsSetXXHASH() { + return fmt.Errorf("unsupported bloom filter hash strategy: %s", hdr.Hash) + } + + if hdr.NumBytes < minimumBloomFilterBytes || hdr.NumBytes > maximumBloomFilterBytes { + return fmt.Errorf("invalid bloom filter size: %d", hdr.NumBytes) + } + + return nil +} + +type BloomFilterReader struct { + Input parquet.ReaderAtSeeker + FileMetadata *FileMetaData + Props *parquet.ReaderProperties + FileDecryptor encryption.FileDecryptor + BufferPool *sync.Pool +} + +func (r *BloomFilterReader) RowGroup(i int) (*RowGroupBloomFilterReader, error) { + if i < 0 || i >= len(r.FileMetadata.RowGroups) { + return nil, fmt.Errorf("row group index %d out of range", i) + } + + rgMeta := r.FileMetadata.RowGroup(i) + return &RowGroupBloomFilterReader{ + input: r.Input, + rgMeta: rgMeta, + fileDecryptor: r.FileDecryptor, + rgOrdinal: int16(i), + bufferPool: r.BufferPool, + sourceFileSize: r.FileMetadata.sourceFileSize, + }, nil +} + +type RowGroupBloomFilterReader struct { + input parquet.ReaderAtSeeker + rgMeta *RowGroupMetaData + fileDecryptor encryption.FileDecryptor + rgOrdinal int16 + sourceFileSize int64 + + bufferPool *sync.Pool +} + +func (r *RowGroupBloomFilterReader) GetColumnBloomFilter(i int) (BloomFilter, error) { + if i < 0 || i >= r.rgMeta.NumColumns() { + return nil, fmt.Errorf("column index %d out of range", i) + } + + col, err := r.rgMeta.ColumnChunk(i) + if err != nil { + return nil, err + } + + var ( + decryptor encryption.Decryptor + header format.BloomFilterHeader + offset int64 + bloomFilterReadSize int32 = 256 + ) + + if offset = col.BloomFilterOffset(); offset <= 0 { + return nil, nil + } + + if col.BloomFilterLength() > 0 { + bloomFilterReadSize = col.BloomFilterLength() + } + + sectionRdr := io.NewSectionReader(r.input, offset, r.sourceFileSize-offset) + cryptoMetadata := col.CryptoMetadata() + if cryptoMetadata != nil { + decryptor, err = encryption.GetColumnMetaDecryptor(cryptoMetadata, r.fileDecryptor) + if err != nil { + return nil, err + } + + encryption.UpdateDecryptor(decryptor, r.rgOrdinal, int16(i), + encryption.BloomFilterHeaderModule) + hdr := decryptor.DecryptFrom(sectionRdr) + if _, err = 
thrift.DeserializeThrift(&header, hdr); err != nil { + return nil, err + } + + if err = validateBloomFilterHeader(&header); err != nil { + return nil, err + } + + encryption.UpdateDecryptor(decryptor, r.rgOrdinal, int16(i), + encryption.BloomFilterBitsetModule) + bitset := decryptor.DecryptFrom(sectionRdr) + if len(bitset) != int(header.NumBytes) { + return nil, fmt.Errorf("wrong length of decrypted bloom filter bitset: %d vs %d", + len(bitset), header.NumBytes) + } + + return &blockSplitBloomFilter{ + data: memory.NewBufferBytes(bitset), + bitset32: arrow.Uint32Traits.CastFromBytes(bitset), + hasher: xxhasher{}, + algorithm: *header.Algorithm, + hashStrategy: *header.Hash, + compression: *header.Compression, + }, nil + } + + headerBuf := r.bufferPool.Get().(*memory.Buffer) + headerBuf.ResizeNoShrink(int(bloomFilterReadSize)) + defer func() { + if headerBuf != nil { + headerBuf.ResizeNoShrink(0) + r.bufferPool.Put(headerBuf) + } + }() + + if _, err = sectionRdr.Read(headerBuf.Bytes()); err != nil { + return nil, err + } + + remaining, err := thrift.DeserializeThrift(&header, headerBuf.Bytes()) + if err != nil { + return nil, err + } + headerSize := len(headerBuf.Bytes()) - int(remaining) + + if err = validateBloomFilterHeader(&header); err != nil { + return nil, err + } + + bloomFilterSz := header.NumBytes + var bitset []byte + if int(bloomFilterSz)+headerSize <= len(headerBuf.Bytes()) { + // bloom filter data is entirely contained in the buffer we just read + bitset = headerBuf.Bytes()[headerSize : headerSize+int(bloomFilterSz)] + } else { + buf := r.bufferPool.Get().(*memory.Buffer) + buf.ResizeNoShrink(int(bloomFilterSz)) + filterBytesInHeader := headerBuf.Len() - headerSize + if filterBytesInHeader > 0 { + copy(buf.Bytes(), headerBuf.Bytes()[headerSize:]) + } + + if _, err = sectionRdr.Read(buf.Bytes()[filterBytesInHeader:]); err != nil { + return nil, err + } + bitset = buf.Bytes() + headerBuf.ResizeNoShrink(0) + r.bufferPool.Put(headerBuf) + headerBuf = buf + } + + bf := &blockSplitBloomFilter{ + data: headerBuf, + bitset32: arrow.GetData[uint32](bitset), + hasher: xxhasher{}, + algorithm: *header.Algorithm, + hashStrategy: *header.Hash, + compression: *header.Compression, + } + headerBuf = nil + addCleanup(bf, r.bufferPool) + return bf, nil +} + +type FileBloomFilterBuilder struct { + Schema *schema.Schema + Encryptor encryption.FileEncryptor + + rgMetaBldrs []*RowGroupMetaDataBuilder + bloomFilters []map[string]BloomFilterBuilder +} + +func (f *FileBloomFilterBuilder) AppendRowGroup(rgMeta *RowGroupMetaDataBuilder, filters map[string]BloomFilterBuilder) { + f.rgMetaBldrs = append(f.rgMetaBldrs, rgMeta) + f.bloomFilters = append(f.bloomFilters, filters) +} + +func (f *FileBloomFilterBuilder) WriteTo(w utils.WriterTell) error { + if len(f.rgMetaBldrs) == 0 || len(f.bloomFilters) == 0 { + return nil + } + + var ( + hdr format.BloomFilterHeader + serializer = thrift.NewThriftSerializer() + ) + for rg, rgMeta := range f.rgMetaBldrs { + if len(f.bloomFilters[rg]) == 0 { + continue + } + + for c, col := range rgMeta.colBuilders { + colPath := col.column.Path() + bf, ok := f.bloomFilters[rg][colPath] + if !ok || bf == nil { + continue + } + + offset := w.Tell() + col.chunk.MetaData.BloomFilterOffset = &offset + var encryptor encryption.Encryptor + if f.Encryptor != nil { + encryptor = f.Encryptor.GetColumnMetaEncryptor(colPath) + } + + if encryptor != nil { + encryptor.UpdateAad(encryption.CreateModuleAad( + encryptor.FileAad(), encryption.BloomFilterHeaderModule, + int16(rg), int16(c), 
encryption.NonPageOrdinal)) + } + + hdr.NumBytes = int32(bf.Size()) + hdr.Algorithm = bf.getAlg() + hdr.Hash = bf.getHashStrategy() + hdr.Compression = bf.getCompression() + + _, err := serializer.Serialize(&hdr, w, encryptor) + if err != nil { + return err + } + + if encryptor != nil { + encryptor.UpdateAad(encryption.CreateModuleAad( + encryptor.FileAad(), encryption.BloomFilterBitsetModule, + int16(rg), int16(c), encryption.NonPageOrdinal)) + } + + if _, err = bf.WriteTo(w, encryptor); err != nil { + return err + } + + dataWritten := int32(w.Tell() - offset) + col.chunk.MetaData.BloomFilterLength = &dataWritten + } + } + return nil +} diff --git a/parquet/metadata/bloom_filter_block.go b/parquet/metadata/bloom_filter_block.go new file mode 100644 index 00000000..bbd84267 --- /dev/null +++ b/parquet/metadata/bloom_filter_block.go @@ -0,0 +1,53 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package metadata + +var ( + checkHash func([]uint32, uint64) bool + checkBulk func([]uint32, []uint64, []bool) + insertHash func([]uint32, uint64) + insertBulk func([]uint32, []uint64) +) + +func checkHashGo(bitset32 []uint32, hash uint64) bool { + bucketIdx := uint32(((hash >> 32) * uint64(len(bitset32)/8)) >> 32) + key := uint32(hash) + + for i := range bitsSetPerBlock { + mask := uint32(1) << ((key * salt[i]) >> 27) + if bitset32[bitsSetPerBlock*bucketIdx+uint32(i)]&mask == 0 { + return false + } + } + return true +} + +func insertHashGo(bitset32 []uint32, hash uint64) { + bucketIdx := uint32(((hash >> 32) * uint64(len(bitset32)/8)) >> 32) + key := uint32(hash) + + for i := range bitsSetPerBlock { + mask := uint32(1) << ((key * salt[i]) >> 27) + bitset32[bitsSetPerBlock*bucketIdx+uint32(i)] |= mask + } +} + +func insertBulkGo(bitset32 []uint32, hash []uint64) { + for _, h := range hash { + insertHash(bitset32, h) + } +} diff --git a/parquet/metadata/bloom_filter_block_amd64.go b/parquet/metadata/bloom_filter_block_amd64.go new file mode 100644 index 00000000..463e6967 --- /dev/null +++ b/parquet/metadata/bloom_filter_block_amd64.go @@ -0,0 +1,36 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !noasm + +package metadata + +import ( + "golang.org/x/sys/cpu" +) + +func init() { + if cpu.X86.HasAVX2 { + checkHash = checkBlockAvx2 + insertHash, insertBulk = insertBlockAvx2, insertBulkAvx2 + } else if cpu.X86.HasSSE42 { + checkHash = checkBlockSSE4 + insertHash, insertBulk = insertBlockSSE4, insertBulkSSE4 + } else { + checkHash = checkHashGo + insertHash, insertBulk = insertHashGo, insertBulkGo + } +} diff --git a/parquet/metadata/bloom_filter_block_avx2_amd64.go b/parquet/metadata/bloom_filter_block_avx2_amd64.go new file mode 100644 index 00000000..01655d05 --- /dev/null +++ b/parquet/metadata/bloom_filter_block_avx2_amd64.go @@ -0,0 +1,45 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !noasm + +package metadata + +import ( + "unsafe" +) + +//go:noescape +func _check_block_avx2(bitset32 unsafe.Pointer, len int, hash uint64) (result bool) + +func checkBlockAvx2(bitset32 []uint32, hash uint64) bool { + return _check_block_avx2(unsafe.Pointer(unsafe.SliceData(bitset32)), len(bitset32), hash) +} + +//go:noescape +func _insert_block_avx2(bitset32 unsafe.Pointer, len int, hash uint64) + +func insertBlockAvx2(bitset32 []uint32, hash uint64) { + _insert_block_avx2(unsafe.Pointer(unsafe.SliceData(bitset32)), len(bitset32), hash) +} + +//go:noescape +func _insert_bulk_avx2(bitset32 unsafe.Pointer, block_len int, hashes unsafe.Pointer, hash_len int) + +func insertBulkAvx2(bitset32 []uint32, hashes []uint64) { + _insert_bulk_avx2(unsafe.Pointer(unsafe.SliceData(bitset32)), len(bitset32), + unsafe.Pointer(unsafe.SliceData(hashes)), len(hashes)) +} diff --git a/parquet/metadata/bloom_filter_block_avx2_amd64.s b/parquet/metadata/bloom_filter_block_avx2_amd64.s new file mode 100644 index 00000000..c176800b --- /dev/null +++ b/parquet/metadata/bloom_filter_block_avx2_amd64.s @@ -0,0 +1,151 @@ +//+build !noasm !appengine +// AUTO-GENERATED BY C2GOASM -- DO NOT EDIT + +DATA LCDATA1<>+0x000(SB)/8, $0x44974d9147b6137b +DATA LCDATA1<>+0x008(SB)/8, $0xa2b7289d8824ad5b +DATA LCDATA1<>+0x010(SB)/8, $0x2df1424b705495c7 +DATA LCDATA1<>+0x018(SB)/8, $0x5c6bfb319efc4947 +DATA LCDATA1<>+0x020(SB)/8, $0x0000000000000001 +GLOBL LCDATA1<>(SB), 8, $40 + +TEXT ·_check_block_avx2(SB), $0-32 + + MOVQ bitset32+0(FP), DI + MOVQ len+8(FP), SI + MOVQ hash+16(FP), DX + LEAQ LCDATA1<>(SB), BP + + WORD $0x8948; BYTE $0xd1 // mov rcx, rdx + LONG $0x20e9c148 // shr rcx, 32 + WORD $0x468d; BYTE $0x07 // lea eax, [rsi + 7] + WORD $0xf685 // test esi, esi + WORD $0x490f; BYTE $0xc6 // cmovns eax, esi + WORD $0xf8c1; BYTE $0x03 // sar eax, 3 + WORD $0x9848 // cdqe + LONG $0xc1af0f48 // imul rax, rcx + LONG $0x1de8c148 // shr rax, 29 + WORD $0xe083; 
BYTE $0xf8 // and eax, -8 + LONG $0x137bca69; WORD $0x47b6 // imul ecx, edx, 1203114875 + WORD $0xe9c1; BYTE $0x1b // shr ecx, 27 + WORD $0x348b; BYTE $0x87 // mov esi, dword [rdi + 4*rax] + WORD $0xa30f; BYTE $0xce // bt esi, ecx + JAE LBB0_8 + LONG $0x4d91ca69; WORD $0x4497 // imul ecx, edx, 1150766481 + WORD $0xe9c1; BYTE $0x1b // shr ecx, 27 + LONG $0x0487748b // mov esi, dword [rdi + 4*rax + 4] + WORD $0xa30f; BYTE $0xce // bt esi, ecx + JAE LBB0_8 + LONG $0xad5bca69; WORD $0x8824 // imul ecx, edx, -2010862245 + WORD $0xe9c1; BYTE $0x1b // shr ecx, 27 + LONG $0x0887748b // mov esi, dword [rdi + 4*rax + 8] + WORD $0xa30f; BYTE $0xce // bt esi, ecx + JAE LBB0_8 + LONG $0x289dca69; WORD $0xa2b7 // imul ecx, edx, -1565054819 + WORD $0xe9c1; BYTE $0x1b // shr ecx, 27 + LONG $0x0c87748b // mov esi, dword [rdi + 4*rax + 12] + WORD $0xa30f; BYTE $0xce // bt esi, ecx + JAE LBB0_8 + LONG $0x95c7ca69; WORD $0x7054 // imul ecx, edx, 1884591559 + WORD $0xe9c1; BYTE $0x1b // shr ecx, 27 + LONG $0x1087748b // mov esi, dword [rdi + 4*rax + 16] + WORD $0xa30f; BYTE $0xce // bt esi, ecx + JAE LBB0_8 + LONG $0x424bca69; WORD $0x2df1 // imul ecx, edx, 770785867 + WORD $0xe9c1; BYTE $0x1b // shr ecx, 27 + LONG $0x1487748b // mov esi, dword [rdi + 4*rax + 20] + WORD $0xa30f; BYTE $0xce // bt esi, ecx + JAE LBB0_8 + LONG $0x4947ca69; WORD $0x9efc // imul ecx, edx, -1627633337 + WORD $0xe9c1; BYTE $0x1b // shr ecx, 27 + LONG $0x1887748b // mov esi, dword [rdi + 4*rax + 24] + WORD $0xa30f; BYTE $0xce // bt esi, ecx + JAE LBB0_8 + LONG $0xfb31ca69; WORD $0x5c6b // imul ecx, edx, 1550580529 + WORD $0xe9c1; BYTE $0x1b // shr ecx, 27 + LONG $0x1c87448b // mov eax, dword [rdi + 4*rax + 28] + WORD $0xa30f; BYTE $0xc8 // bt eax, ecx + WORD $0x920f; BYTE $0xd0 // setb al + MOVQ AX, result+24(FP) + RET + +LBB0_8: + WORD $0xc031 // xor eax, eax + MOVQ AX, result+24(FP) + RET + +TEXT ·_insert_block_avx2(SB), $0-24 + + MOVQ bitset32+0(FP), DI + MOVQ len+8(FP), SI + MOVQ hash+16(FP), DX + LEAQ LCDATA1<>(SB), BP + + LONG $0xc26ef9c5 // vmovd xmm0, edx + LONG $0x20eac148 // shr rdx, 32 + WORD $0x468d; BYTE $0x07 // lea eax, [rsi + 7] + WORD $0xf685 // test esi, esi + WORD $0x490f; BYTE $0xc6 // cmovns eax, esi + WORD $0xf8c1; BYTE $0x03 // sar eax, 3 + WORD $0x9848 // cdqe + LONG $0xc2af0f48 // imul rax, rdx + LONG $0x1be8c148 // shr rax, 27 + QUAD $0x0003ffffffe0b948; WORD $0x0000 // mov rcx, 17179869152 + LONG $0x587de2c4; BYTE $0xc0 // vpbroadcastd ymm0, xmm0 + LONG $0x407de2c4; WORD $0x0045 // vpmulld ymm0, ymm0, yword 0[rbp] /* [rip + .LCPI2_0] */ + WORD $0x2148; BYTE $0xc1 // and rcx, rax + LONG $0xd072fdc5; BYTE $0x1b // vpsrld ymm0, ymm0, 27 + LONG $0x587de2c4; WORD $0x204d // vpbroadcastd ymm1, dword 32[rbp] /* [rip + .LCPI2_1] */ + LONG $0x4775e2c4; BYTE $0xc0 // vpsllvd ymm0, ymm1, ymm0 + LONG $0x04ebfdc5; BYTE $0x0f // vpor ymm0, ymm0, yword [rdi + rcx] + LONG $0x047ffec5; BYTE $0x0f // vmovdqu yword [rdi + rcx], ymm0 + VZEROUPPER + RET + +DATA LCDATA2<>+0x000(SB)/8, $0x44974d9147b6137b +DATA LCDATA2<>+0x008(SB)/8, $0xa2b7289d8824ad5b +DATA LCDATA2<>+0x010(SB)/8, $0x2df1424b705495c7 +DATA LCDATA2<>+0x018(SB)/8, $0x5c6bfb319efc4947 +DATA LCDATA2<>+0x020(SB)/8, $0x0000000000000001 +GLOBL LCDATA2<>(SB), 8, $40 + +TEXT ·_insert_bulk_avx2(SB), $0-32 + + MOVQ bitset32+0(FP), DI + MOVQ block_len+8(FP), SI + MOVQ hashes+16(FP), DX + MOVQ hash_len+24(FP), CX + LEAQ LCDATA2<>(SB), BP + + WORD $0xc985 // test ecx, ecx + JLE LBB3_4 + WORD $0x468d; BYTE $0x07 // lea eax, [rsi + 7] + WORD $0xf685 // test esi, esi + WORD 
$0x490f; BYTE $0xc6 // cmovns eax, esi + WORD $0xf8c1; BYTE $0x03 // sar eax, 3 + WORD $0x9848 // cdqe + WORD $0xc989 // mov ecx, ecx + WORD $0xf631 // xor esi, esi + QUAD $0x0003ffffffe0b849; WORD $0x0000 // mov r8, 17179869152 + LONG $0x456ffdc5; BYTE $0x00 // vmovdqa ymm0, yword 0[rbp] /* [rip + .LCPI3_0] */ + LONG $0x587de2c4; WORD $0x204d // vpbroadcastd ymm1, dword 32[rbp] /* [rip + .LCPI3_1] */ + +LBB3_2: + LONG $0xf20c8b4c // mov r9, qword [rdx + 8*rsi] + LONG $0x6e79c1c4; BYTE $0xd1 // vmovd xmm2, r9d + LONG $0x20e9c149 // shr r9, 32 + LONG $0xc8af0f4c // imul r9, rax + LONG $0x1be9c149 // shr r9, 27 + WORD $0x214d; BYTE $0xc1 // and r9, r8 + LONG $0x587de2c4; BYTE $0xd2 // vpbroadcastd ymm2, xmm2 + LONG $0x406de2c4; BYTE $0xd0 // vpmulld ymm2, ymm2, ymm0 + LONG $0xd272edc5; BYTE $0x1b // vpsrld ymm2, ymm2, 27 + LONG $0x4775e2c4; BYTE $0xd2 // vpsllvd ymm2, ymm1, ymm2 + LONG $0xeb6da1c4; WORD $0x0f14 // vpor ymm2, ymm2, yword [rdi + r9] + LONG $0x7f7ea1c4; WORD $0x0f14 // vmovdqu yword [rdi + r9], ymm2 + WORD $0xff48; BYTE $0xc6 // inc rsi + WORD $0x3948; BYTE $0xf1 // cmp rcx, rsi + JNE LBB3_2 + +LBB3_4: + VZEROUPPER + RET diff --git a/parquet/metadata/bloom_filter_block_default.go b/parquet/metadata/bloom_filter_block_default.go new file mode 100644 index 00000000..ca1ec750 --- /dev/null +++ b/parquet/metadata/bloom_filter_block_default.go @@ -0,0 +1,23 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build noasm || !amd64 + +package metadata + +func init() { + checkHash, insertHash, insertBulk = checkHashGo, insertHashGo, insertBulkGo +} diff --git a/parquet/metadata/bloom_filter_block_sse4_amd64.go b/parquet/metadata/bloom_filter_block_sse4_amd64.go new file mode 100644 index 00000000..7c94bf73 --- /dev/null +++ b/parquet/metadata/bloom_filter_block_sse4_amd64.go @@ -0,0 +1,45 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +//go:build !noasm + +package metadata + +import ( + "unsafe" +) + +//go:noescape +func _check_block_sse4(bitset32 unsafe.Pointer, len int, hash uint64) (result bool) + +func checkBlockSSE4(bitset32 []uint32, hash uint64) bool { + return _check_block_sse4(unsafe.Pointer(unsafe.SliceData(bitset32)), len(bitset32), hash) +} + +//go:noescape +func _insert_block_sse4(bitset32 unsafe.Pointer, len int, hash uint64) + +func insertBlockSSE4(bitset32 []uint32, hash uint64) { + _insert_block_sse4(unsafe.Pointer(unsafe.SliceData(bitset32)), len(bitset32), hash) +} + +//go:noescape +func _insert_bulk_sse4(bitset32 unsafe.Pointer, block_len int, hashes unsafe.Pointer, hash_len int) + +func insertBulkSSE4(bitset32 []uint32, hashes []uint64) { + _insert_bulk_sse4(unsafe.Pointer(unsafe.SliceData(bitset32)), len(bitset32), + unsafe.Pointer(unsafe.SliceData(hashes)), len(hashes)) +} diff --git a/parquet/metadata/bloom_filter_block_sse4_amd64.s b/parquet/metadata/bloom_filter_block_sse4_amd64.s new file mode 100644 index 00000000..ac5ff12d --- /dev/null +++ b/parquet/metadata/bloom_filter_block_sse4_amd64.s @@ -0,0 +1,176 @@ +//+build !noasm !appengine +// AUTO-GENERATED BY C2GOASM -- DO NOT EDIT + +DATA LCDATA1<>+0x000(SB)/8, $0x44974d9147b6137b +DATA LCDATA1<>+0x008(SB)/8, $0xa2b7289d8824ad5b +DATA LCDATA1<>+0x010(SB)/8, $0x3f8000003f800000 +DATA LCDATA1<>+0x018(SB)/8, $0x3f8000003f800000 +DATA LCDATA1<>+0x020(SB)/8, $0x2df1424b705495c7 +DATA LCDATA1<>+0x028(SB)/8, $0x5c6bfb319efc4947 +GLOBL LCDATA1<>(SB), 8, $48 + +TEXT ·_check_block_sse4(SB), $0-32 + + MOVQ bitset32+0(FP), DI + MOVQ len+8(FP), SI + MOVQ hash+16(FP), DX + LEAQ LCDATA1<>(SB), BP + + WORD $0x8948; BYTE $0xd1 // mov rcx, rdx + LONG $0x20e9c148 // shr rcx, 32 + WORD $0x468d; BYTE $0x07 // lea eax, [rsi + 7] + WORD $0xf685 // test esi, esi + WORD $0x490f; BYTE $0xc6 // cmovns eax, esi + WORD $0xf8c1; BYTE $0x03 // sar eax, 3 + WORD $0x9848 // cdqe + LONG $0xc1af0f48 // imul rax, rcx + LONG $0x1de8c148 // shr rax, 29 + WORD $0xe083; BYTE $0xf8 // and eax, -8 + LONG $0x137bca69; WORD $0x47b6 // imul ecx, edx, 1203114875 + WORD $0xe9c1; BYTE $0x1b // shr ecx, 27 + WORD $0x348b; BYTE $0x87 // mov esi, dword [rdi + 4*rax] + WORD $0xa30f; BYTE $0xce // bt esi, ecx + JAE LBB0_8 + LONG $0x4d91ca69; WORD $0x4497 // imul ecx, edx, 1150766481 + WORD $0xe9c1; BYTE $0x1b // shr ecx, 27 + LONG $0x0487748b // mov esi, dword [rdi + 4*rax + 4] + WORD $0xa30f; BYTE $0xce // bt esi, ecx + JAE LBB0_8 + LONG $0xad5bca69; WORD $0x8824 // imul ecx, edx, -2010862245 + WORD $0xe9c1; BYTE $0x1b // shr ecx, 27 + LONG $0x0887748b // mov esi, dword [rdi + 4*rax + 8] + WORD $0xa30f; BYTE $0xce // bt esi, ecx + JAE LBB0_8 + LONG $0x289dca69; WORD $0xa2b7 // imul ecx, edx, -1565054819 + WORD $0xe9c1; BYTE $0x1b // shr ecx, 27 + LONG $0x0c87748b // mov esi, dword [rdi + 4*rax + 12] + WORD $0xa30f; BYTE $0xce // bt esi, ecx + JAE LBB0_8 + LONG $0x95c7ca69; WORD $0x7054 // imul ecx, edx, 1884591559 + WORD $0xe9c1; BYTE $0x1b // shr ecx, 27 + LONG $0x1087748b // mov esi, dword [rdi + 4*rax + 16] + WORD $0xa30f; BYTE $0xce // bt esi, ecx + JAE LBB0_8 + LONG $0x424bca69; WORD $0x2df1 // imul ecx, edx, 770785867 + WORD $0xe9c1; BYTE $0x1b // shr ecx, 27 + LONG $0x1487748b // mov esi, dword [rdi + 4*rax + 20] + WORD $0xa30f; BYTE $0xce // bt esi, ecx + JAE LBB0_8 + LONG $0x4947ca69; WORD $0x9efc // imul ecx, edx, -1627633337 + WORD $0xe9c1; BYTE $0x1b // shr ecx, 27 + LONG $0x1887748b // mov esi, dword [rdi + 4*rax + 24] + WORD $0xa30f; BYTE $0xce // bt esi, ecx + JAE LBB0_8 + 
LONG $0xfb31ca69; WORD $0x5c6b // imul ecx, edx, 1550580529 + WORD $0xe9c1; BYTE $0x1b // shr ecx, 27 + LONG $0x1c87448b // mov eax, dword [rdi + 4*rax + 28] + WORD $0xa30f; BYTE $0xc8 // bt eax, ecx + WORD $0x920f; BYTE $0xd0 // setb al + MOVQ AX, result+24(FP) + RET + +LBB0_8: + WORD $0xc031 // xor eax, eax + MOVQ AX, result+24(FP) + RET + +TEXT ·_insert_block_sse4(SB), $0-24 + + MOVQ bitset32+0(FP), DI + MOVQ len+8(FP), SI + MOVQ hash+16(FP), DX + LEAQ LCDATA1<>(SB), BP + + LONG $0xc26e0f66 // movd xmm0, edx + LONG $0x20eac148 // shr rdx, 32 + WORD $0x468d; BYTE $0x07 // lea eax, [rsi + 7] + WORD $0xf685 // test esi, esi + WORD $0x490f; BYTE $0xc6 // cmovns eax, esi + WORD $0xf8c1; BYTE $0x03 // sar eax, 3 + WORD $0x6348; BYTE $0xc8 // movsxd rcx, eax + LONG $0xcaaf0f48 // imul rcx, rdx + LONG $0x1be9c148 // shr rcx, 27 + QUAD $0x0003ffffffe0b848; WORD $0x0000 // mov rax, 17179869152 + WORD $0x2148; BYTE $0xc8 // and rax, rcx + LONG $0xc0700f66; BYTE $0x00 // pshufd xmm0, xmm0, 0 + LONG $0x4d6f0f66; BYTE $0x00 // movdqa xmm1, oword 0[rbp] /* [rip + .LCPI2_0] */ + LONG $0x40380f66; BYTE $0xc8 // pmulld xmm1, xmm0 + LONG $0xd1720f66; BYTE $0x1b // psrld xmm1, 27 + LONG $0xf1720f66; BYTE $0x17 // pslld xmm1, 23 + LONG $0x556f0f66; BYTE $0x10 // movdqa xmm2, oword 16[rbp] /* [rip + .LCPI2_1] */ + LONG $0xcafe0f66 // paddd xmm1, xmm2 + LONG $0xc95b0ff3 // cvttps2dq xmm1, xmm1 + LONG $0x071c100f // movups xmm3, oword [rdi + rax] + WORD $0x560f; BYTE $0xd9 // orps xmm3, xmm1 + LONG $0x074c100f; BYTE $0x10 // movups xmm1, oword [rdi + rax + 16] + LONG $0x071c110f // movups oword [rdi + rax], xmm3 + LONG $0x40380f66; WORD $0x2045 // pmulld xmm0, oword 32[rbp] /* [rip + .LCPI2_2] */ + LONG $0xd0720f66; BYTE $0x1b // psrld xmm0, 27 + LONG $0xf0720f66; BYTE $0x17 // pslld xmm0, 23 + LONG $0xc2fe0f66 // paddd xmm0, xmm2 + LONG $0xc05b0ff3 // cvttps2dq xmm0, xmm0 + WORD $0x560f; BYTE $0xc1 // orps xmm0, xmm1 + LONG $0x0744110f; BYTE $0x10 // movups oword [rdi + rax + 16], xmm0 + RET + +DATA LCDATA2<>+0x000(SB)/8, $0x44974d9147b6137b +DATA LCDATA2<>+0x008(SB)/8, $0xa2b7289d8824ad5b +DATA LCDATA2<>+0x010(SB)/8, $0x3f8000003f800000 +DATA LCDATA2<>+0x018(SB)/8, $0x3f8000003f800000 +DATA LCDATA2<>+0x020(SB)/8, $0x2df1424b705495c7 +DATA LCDATA2<>+0x028(SB)/8, $0x5c6bfb319efc4947 +GLOBL LCDATA2<>(SB), 8, $48 + +TEXT ·_insert_bulk_sse4(SB), $0-32 + + MOVQ bitset32+0(FP), DI + MOVQ block_len+8(FP), SI + MOVQ hashes+16(FP), DX + MOVQ hash_len+24(FP), CX + LEAQ LCDATA2<>(SB), BP + + WORD $0xc985 // test ecx, ecx + JLE LBB3_4 + WORD $0x468d; BYTE $0x07 // lea eax, [rsi + 7] + WORD $0xf685 // test esi, esi + WORD $0x490f; BYTE $0xc6 // cmovns eax, esi + WORD $0xf8c1; BYTE $0x03 // sar eax, 3 + WORD $0x9848 // cdqe + WORD $0xc989 // mov ecx, ecx + WORD $0xf631 // xor esi, esi + QUAD $0x0003ffffffe0b849; WORD $0x0000 // mov r8, 17179869152 + LONG $0x456f0f66; BYTE $0x00 // movdqa xmm0, oword 0[rbp] /* [rip + .LCPI3_0] */ + LONG $0x4d6f0f66; BYTE $0x10 // movdqa xmm1, oword 16[rbp] /* [rip + .LCPI3_1] */ + LONG $0x556f0f66; BYTE $0x20 // movdqa xmm2, oword 32[rbp] /* [rip + .LCPI3_2] */ + +LBB3_2: + LONG $0xf20c8b4c // mov r9, qword [rdx + 8*rsi] + LONG $0x6e0f4166; BYTE $0xd9 // movd xmm3, r9d + LONG $0x20e9c149 // shr r9, 32 + LONG $0xc8af0f4c // imul r9, rax + LONG $0x1be9c149 // shr r9, 27 + WORD $0x214d; BYTE $0xc1 // and r9, r8 + LONG $0xdb700f66; BYTE $0x00 // pshufd xmm3, xmm3, 0 + LONG $0xe36f0f66 // movdqa xmm4, xmm3 + LONG $0x40380f66; BYTE $0xe0 // pmulld xmm4, xmm0 + LONG $0xd4720f66; BYTE $0x1b // 
psrld xmm4, 27 + LONG $0xf4720f66; BYTE $0x17 // pslld xmm4, 23 + LONG $0xe1fe0f66 // paddd xmm4, xmm1 + LONG $0xe45b0ff3 // cvttps2dq xmm4, xmm4 + LONG $0x2c100f42; BYTE $0x0f // movups xmm5, oword [rdi + r9] + WORD $0x560f; BYTE $0xec // orps xmm5, xmm4 + LONG $0x64100f42; WORD $0x100f // movups xmm4, oword [rdi + r9 + 16] + LONG $0x2c110f42; BYTE $0x0f // movups oword [rdi + r9], xmm5 + LONG $0x40380f66; BYTE $0xda // pmulld xmm3, xmm2 + LONG $0xd3720f66; BYTE $0x1b // psrld xmm3, 27 + LONG $0xf3720f66; BYTE $0x17 // pslld xmm3, 23 + LONG $0xd9fe0f66 // paddd xmm3, xmm1 + LONG $0xdb5b0ff3 // cvttps2dq xmm3, xmm3 + WORD $0x560f; BYTE $0xdc // orps xmm3, xmm4 + LONG $0x5c110f42; WORD $0x100f // movups oword [rdi + r9 + 16], xmm3 + WORD $0xff48; BYTE $0xc6 // inc rsi + WORD $0x3948; BYTE $0xf1 // cmp rcx, rsi + JNE LBB3_2 + +LBB3_4: + RET diff --git a/parquet/metadata/bloom_filter_reader_test.go b/parquet/metadata/bloom_filter_reader_test.go new file mode 100644 index 00000000..432b90f6 --- /dev/null +++ b/parquet/metadata/bloom_filter_reader_test.go @@ -0,0 +1,349 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
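// A minimal read-path sketch for the API exercised by the tests below (the file
// path and column index are placeholders; GetColumnBloomFilter returns nil when
// the column chunk carries no filter):
//
//	rdr, _ := file.OpenParquetFile("example.parquet", false)
//	bfRdr := rdr.GetBloomFilterReader()
//	rg, _ := bfRdr.RowGroup(0)
//	bf, _ := rg.GetColumnBloomFilter(0)
//	typed := metadata.TypedBloomFilter[parquet.ByteArray]{bf}
//	_ = typed.Check(parquet.ByteArray("Hello")) // may report false positives, never false negatives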
+ +package metadata_test + +import ( + "bytes" + "os" + "runtime" + "sync" + "testing" + + "github.com/apache/arrow-go/v18/arrow/memory" + "github.com/apache/arrow-go/v18/parquet" + "github.com/apache/arrow-go/v18/parquet/file" + "github.com/apache/arrow-go/v18/parquet/internal/encryption" + "github.com/apache/arrow-go/v18/parquet/internal/utils" + "github.com/apache/arrow-go/v18/parquet/metadata" + "github.com/apache/arrow-go/v18/parquet/schema" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +type BloomFilterBuilderSuite struct { + suite.Suite + + sc *schema.Schema + props *parquet.WriterProperties + mem *memory.CheckedAllocator + buf bytes.Buffer +} + +func (suite *BloomFilterBuilderSuite) SetupTest() { + suite.props = parquet.NewWriterProperties() + suite.mem = memory.NewCheckedAllocator(memory.NewGoAllocator()) + suite.buf.Reset() +} + +func (suite *BloomFilterBuilderSuite) TearDownTest() { + runtime.GC() // we use setfinalizer to clean up the buffers, so run the GC + suite.mem.AssertSize(suite.T(), 0) +} + +func (suite *BloomFilterBuilderSuite) TestSingleRowGroup() { + suite.sc = schema.NewSchema(schema.MustGroup( + schema.NewGroupNode("schema", parquet.Repetitions.Repeated, + schema.FieldList{ + schema.NewByteArrayNode("c1", parquet.Repetitions.Optional, -1), + schema.NewByteArrayNode("c2", parquet.Repetitions.Optional, -1), + schema.NewByteArrayNode("c3", parquet.Repetitions.Optional, -1), + }, -1))) + + bldr := metadata.FileBloomFilterBuilder{Schema: suite.sc} + + metaBldr := metadata.NewFileMetadataBuilder(suite.sc, suite.props, nil) + + { + rgMeta := metaBldr.AppendRowGroup() + filterMap := make(map[string]metadata.BloomFilterBuilder) + bldr.AppendRowGroup(rgMeta, filterMap) + + bf1 := metadata.NewBloomFilter(32, 1024, suite.mem) + bf2 := metadata.NewAdaptiveBlockSplitBloomFilter(1024, 5, 0.01, suite.sc.Column(2), suite.mem) + + h1, h2 := bf1.Hasher(), bf2.Hasher() + + rgMeta.NextColumnChunk() + rgMeta.NextColumnChunk() + rgMeta.NextColumnChunk() + rgMeta.Finish(0, 0) + + bf1.InsertHash(metadata.GetHash(h1, parquet.ByteArray("Hello"))) + bf2.InsertHash(metadata.GetHash(h2, parquet.ByteArray("World"))) + filterMap["c1"] = bf1 + filterMap["c3"] = bf2 + + wr := &utils.TellWrapper{Writer: &suite.buf} + wr.Write([]byte("PAR1")) // offset of 0 means unset, so write something + // to force the offset to be set as a non-zero value + suite.Require().NoError(bldr.WriteTo(wr)) + } + runtime.GC() + + finalMeta, err := metaBldr.Finish() + suite.Require().NoError(err) + { + bufferPool := &sync.Pool{ + New: func() interface{} { + buf := memory.NewResizableBuffer(suite.mem) + runtime.SetFinalizer(buf, func(obj *memory.Buffer) { + obj.Release() + }) + return buf + }, + } + + rdr := metadata.BloomFilterReader{ + Input: bytes.NewReader(suite.buf.Bytes()), + FileMetadata: finalMeta, + BufferPool: bufferPool, + } + + bfr, err := rdr.RowGroup(0) + suite.Require().NoError(err) + suite.Require().NotNil(bfr) + + { + bf1, err := bfr.GetColumnBloomFilter(0) + suite.Require().NoError(err) + suite.Require().NotNil(bf1) + suite.False(bf1.CheckHash(metadata.GetHash(bf1.Hasher(), parquet.ByteArray("World")))) + suite.True(bf1.CheckHash(metadata.GetHash(bf1.Hasher(), parquet.ByteArray("Hello")))) + } + runtime.GC() // force GC to run to put the buffer back into the pool + { + bf2, err := bfr.GetColumnBloomFilter(1) + suite.Require().NoError(err) + suite.Require().Nil(bf2) + } + { + bf3, err := bfr.GetColumnBloomFilter(2) + 
suite.Require().NoError(err) + suite.Require().NotNil(bf3) + suite.False(bf3.CheckHash(metadata.GetHash(bf3.Hasher(), parquet.ByteArray("Hello")))) + suite.True(bf3.CheckHash(metadata.GetHash(bf3.Hasher(), parquet.ByteArray("World")))) + } + runtime.GC() // we're using setfinalizer, so force release + } + runtime.GC() +} + +const ( + FooterEncryptionKey = "0123456789012345" + ColumnEncryptionKey1 = "1234567890123450" + ColumnEncryptionKey2 = "1234567890123451" + FooterEncryptionKeyID = "kf" + ColumnEncryptionKey1ID = "kc1" + ColumnEncryptionKey2ID = "kc2" +) + +type EncryptedBloomFilterBuilderSuite struct { + suite.Suite + + sc *schema.Schema + props *parquet.WriterProperties + decryptProps *parquet.FileDecryptionProperties + mem *memory.CheckedAllocator + buf bytes.Buffer +} + +func (suite *EncryptedBloomFilterBuilderSuite) SetupTest() { + encryptedCols := parquet.ColumnPathToEncryptionPropsMap{ + "c1": parquet.NewColumnEncryptionProperties("c1", + parquet.WithKey(ColumnEncryptionKey1), parquet.WithKeyID(ColumnEncryptionKey1ID)), + "c2": parquet.NewColumnEncryptionProperties("c2", + parquet.WithKey(ColumnEncryptionKey2), parquet.WithKeyID(ColumnEncryptionKey2ID)), + } + + encProps := parquet.NewFileEncryptionProperties(FooterEncryptionKey, + parquet.WithFooterKeyID(FooterEncryptionKeyID), + parquet.WithEncryptedColumns(encryptedCols)) + + suite.decryptProps = parquet.NewFileDecryptionProperties( + parquet.WithFooterKey(FooterEncryptionKey), + parquet.WithColumnKeys(parquet.ColumnPathToDecryptionPropsMap{ + "c1": parquet.NewColumnDecryptionProperties("c1", parquet.WithDecryptKey(ColumnEncryptionKey1)), + "c2": parquet.NewColumnDecryptionProperties("c2", parquet.WithDecryptKey(ColumnEncryptionKey2)), + })) + + suite.props = parquet.NewWriterProperties(parquet.WithEncryptionProperties(encProps)) + suite.mem = memory.NewCheckedAllocator(memory.NewGoAllocator()) + suite.buf.Reset() +} + +func (suite *EncryptedBloomFilterBuilderSuite) TearDownTest() { + runtime.GC() // we use setfinalizer to clean up the buffers, so run the GC + suite.mem.AssertSize(suite.T(), 0) +} + +func (suite *EncryptedBloomFilterBuilderSuite) TestEncryptedBloomFilters() { + suite.sc = schema.NewSchema(schema.MustGroup( + schema.NewGroupNode("schema", parquet.Repetitions.Repeated, + schema.FieldList{ + schema.NewByteArrayNode("c1", parquet.Repetitions.Optional, -1), + schema.NewByteArrayNode("c2", parquet.Repetitions.Optional, -1), + schema.NewByteArrayNode("c3", parquet.Repetitions.Optional, -1), + }, -1))) + + encryptor := encryption.NewFileEncryptor(suite.props.FileEncryptionProperties(), suite.mem) + metaBldr := metadata.NewFileMetadataBuilder(suite.sc, suite.props, nil) + metaBldr.SetFileEncryptor(encryptor) + bldr := metadata.FileBloomFilterBuilder{Schema: suite.sc, Encryptor: encryptor} + { + rgMeta := metaBldr.AppendRowGroup() + filterMap := make(map[string]metadata.BloomFilterBuilder) + bldr.AppendRowGroup(rgMeta, filterMap) + + bf1 := metadata.NewBloomFilter(32, 1024, suite.mem) + bf2 := metadata.NewAdaptiveBlockSplitBloomFilter(1024, 5, 0.01, suite.sc.Column(1), suite.mem) + h1, h2 := bf1.Hasher(), bf2.Hasher() + + bf1.InsertHash(metadata.GetHash(h1, parquet.ByteArray("Hello"))) + bf2.InsertHash(metadata.GetHash(h2, parquet.ByteArray("World"))) + filterMap["c1"] = bf1 + filterMap["c2"] = bf2 + + colChunk1 := rgMeta.NextColumnChunk() + colChunk1.Finish(metadata.ChunkMetaInfo{}, false, false, metadata.EncodingStats{}) + + colChunk2 := rgMeta.NextColumnChunk() + colChunk2.Finish(metadata.ChunkMetaInfo{}, false, 
false, metadata.EncodingStats{}) + + colChunk3 := rgMeta.NextColumnChunk() + colChunk3.Finish(metadata.ChunkMetaInfo{}, false, false, metadata.EncodingStats{}) + + wr := &utils.TellWrapper{Writer: &suite.buf} + wr.Write([]byte("PAR1")) // offset of 0 means unset, so write something + // to force the offset to be set as a non-zero value + suite.Require().NoError(bldr.WriteTo(wr)) + + rgMeta.Finish(0, 0) + } + + finalMeta, err := metaBldr.Finish() + suite.Require().NoError(err) + finalMeta.FileDecryptor = encryption.NewFileDecryptor(suite.decryptProps, + suite.props.FileEncryptionProperties().FileAad(), + suite.props.FileEncryptionProperties().Algorithm().Algo, "", suite.mem) + { + bufferPool := &sync.Pool{ + New: func() interface{} { + buf := memory.NewResizableBuffer(suite.mem) + runtime.SetFinalizer(buf, func(obj *memory.Buffer) { + obj.Release() + }) + return buf + }, + } + defer runtime.GC() + + rdr := metadata.BloomFilterReader{ + Input: bytes.NewReader(suite.buf.Bytes()), + FileMetadata: finalMeta, + BufferPool: bufferPool, + FileDecryptor: finalMeta.FileDecryptor, + } + + bfr, err := rdr.RowGroup(0) + suite.Require().NoError(err) + suite.Require().NotNil(bfr) + + { + bf1, err := bfr.GetColumnBloomFilter(0) + suite.Require().NoError(err) + suite.Require().NotNil(bf1) + suite.False(bf1.CheckHash(metadata.GetHash(bf1.Hasher(), parquet.ByteArray("World")))) + suite.True(bf1.CheckHash(metadata.GetHash(bf1.Hasher(), parquet.ByteArray("Hello")))) + } + } +} + +func TestBloomFilterRoundTrip(t *testing.T) { + suite.Run(t, new(BloomFilterBuilderSuite)) + suite.Run(t, new(EncryptedBloomFilterBuilderSuite)) +} + +func TestReadBloomFilter(t *testing.T) { + dir := os.Getenv("PARQUET_TEST_DATA") + if dir == "" { + t.Skip("PARQUET_TEST_DATA not set") + } + require.DirExists(t, dir) + + files := []string{"data_index_bloom_encoding_stats.parquet", + "data_index_bloom_encoding_with_length.parquet"} + + for _, testfile := range files { + t.Run(testfile, func(t *testing.T) { + rdr, err := file.OpenParquetFile(dir+"/"+testfile, false) + require.NoError(t, err) + defer rdr.Close() + + bloomFilterRdr := rdr.GetBloomFilterReader() + rg0, err := bloomFilterRdr.RowGroup(0) + require.NoError(t, err) + require.NotNil(t, rg0) + + rg1, err := bloomFilterRdr.RowGroup(1) + assert.Nil(t, rg1) + assert.Error(t, err) + assert.ErrorContains(t, err, "row group index 1 out of range") + + bf, err := rg0.GetColumnBloomFilter(0) + require.NoError(t, err) + require.NotNil(t, bf) + + bf1, err := rg0.GetColumnBloomFilter(1) + assert.Nil(t, bf1) + assert.Error(t, err) + assert.ErrorContains(t, err, "column index 1 out of range") + + baBloomFilter := metadata.TypedBloomFilter[parquet.ByteArray]{bf} + assert.True(t, baBloomFilter.Check([]byte("Hello"))) + assert.False(t, baBloomFilter.Check([]byte("NOT_EXISTS"))) + }) + } +} + +func TestBloomFilterReaderFileNotHaveFilter(t *testing.T) { + // can still get a BloomFilterReader and a RowGroupBloomFilterReader + // but cannot get a non-null BloomFilter + dir := os.Getenv("PARQUET_TEST_DATA") + if dir == "" { + t.Skip("PARQUET_TEST_DATA not set") + } + require.DirExists(t, dir) + + rdr, err := file.OpenParquetFile(dir+"/alltypes_plain.parquet", false) + require.NoError(t, err) + defer rdr.Close() + + bloomFilterRdr := rdr.GetBloomFilterReader() + rg0, err := bloomFilterRdr.RowGroup(0) + require.NoError(t, err) + require.NotNil(t, rg0) + + rg1, err := bloomFilterRdr.RowGroup(1) + assert.Nil(t, rg1) + assert.Error(t, err) + assert.ErrorContains(t, err, "row group index 1 out of range") + 
+ bf, err := rg0.GetColumnBloomFilter(0) + require.NoError(t, err) + require.Nil(t, bf) +} diff --git a/parquet/metadata/bloom_filter_test.go b/parquet/metadata/bloom_filter_test.go new file mode 100644 index 00000000..1f8a6011 --- /dev/null +++ b/parquet/metadata/bloom_filter_test.go @@ -0,0 +1,190 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package metadata + +import ( + "fmt" + "math/rand/v2" + "runtime" + "testing" + + "github.com/apache/arrow-go/v18/arrow/bitutil" + "github.com/apache/arrow-go/v18/arrow/memory" + "github.com/apache/arrow-go/v18/parquet" + "github.com/stretchr/testify/assert" +) + +func TestSplitBlockFilter(t *testing.T) { + const N = 1000 + const S = 3 + const P = 0.01 + + bf := blockSplitBloomFilter{ + bitset32: make([]uint32, optimalNumBytes(N, P)), + } + + p := rand.New(rand.NewPCG(S, S)) + for i := 0; i < N; i++ { + bf.InsertHash(p.Uint64()) + } + + falsePositives := 0 + p = rand.New(rand.NewPCG(S, S)) + for i := 0; i < N; i++ { + x := p.Uint64() + + if !bf.CheckHash(x) { + t.Fatalf("bloom filter block does not contain value #%d that was inserted %d", i, x) + } + + if bf.CheckHash(^x) { + falsePositives++ + } + } + + if r := (float64(falsePositives) / N); r > P { + t.Fatalf("false positive rate is too high: %f", r) + } +} + +func testHash[T parquet.ColumnTypes](t assert.TestingT, h Hasher, vals []T) { + results := GetHashes(h, vals) + assert.Len(t, results, len(vals)) + for i, v := range vals { + assert.Equal(t, GetHash(h, v), results[i]) + } + + var ( + nvalid = int64(len(vals)) + validBits = make([]byte, bitutil.BytesForBits(2*nvalid)) + spacedVals = make([]T, 2*nvalid) + ) + + for i, v := range vals { + spacedVals[i*2] = v + bitutil.SetBit(validBits, i*2) + + } + + results = GetSpacedHashes(h, nvalid, spacedVals, validBits, 0) + assert.Len(t, results, len(vals)) + for i, v := range vals { + assert.Equal(t, GetHash(h, v), results[i]) + } +} + +func TestGetHashes(t *testing.T) { + var ( + h xxhasher + valsBA = []parquet.ByteArray{ + []byte("hello"), + []byte("world"), + } + + valsFLBA = []parquet.FixedLenByteArray{ + []byte("hello"), + []byte("world"), + } + + valsI32 = []int32{42, 43} + ) + + assert.Len(t, GetSpacedHashes[int32](h, 0, nil, nil, 0), 0) + + testHash(t, h, valsBA) + testHash(t, h, valsFLBA) + testHash(t, h, valsI32) +} + +func TestNewBloomFilter(t *testing.T) { + tests := []struct { + ndv uint32 + fpp float64 + maxBytes int64 + expectedBytes int64 + }{ + {1, 0.09, 0, 0}, + // cap at maximumBloomFilterBytes + {1024 * 1024 * 128, 0.9, maximumBloomFilterBytes + 1, maximumBloomFilterBytes}, + // round to power of 2 + {1024 * 1024, 0.01, maximumBloomFilterBytes, 1 << 21}, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("ndv=%d,fpp=%0.3f", tt.ndv, tt.fpp), func(t *testing.T) { + mem := 
memory.NewCheckedAllocator(memory.DefaultAllocator) + defer mem.AssertSize(t, 0) + + { + bf := NewBloomFilterFromNDVAndFPP(tt.ndv, tt.fpp, tt.maxBytes, mem) + assert.EqualValues(t, tt.expectedBytes, bf.Size()) + runtime.GC() + } + runtime.GC() // force GC to run and do the cleanup routines + }) + } +} + +func BenchmarkFilterInsert(b *testing.B) { + bf := blockSplitBloomFilter{bitset32: make([]uint32, 8)} + b.ResetTimer() + for i := 0; i < b.N; i++ { + bf.InsertHash(uint64(i)) + } + b.SetBytes(bytesPerFilterBlock) +} + +func BenchmarkFilterCheck(b *testing.B) { + bf := blockSplitBloomFilter{bitset32: make([]uint32, 8)} + bf.InsertHash(42) + b.ResetTimer() + for i := 0; i < b.N; i++ { + bf.CheckHash(42) + } + b.SetBytes(bytesPerFilterBlock) +} + +func BenchmarkFilterCheckBulk(b *testing.B) { + bf := blockSplitBloomFilter{bitset32: make([]uint32, 99*bitsSetPerBlock)} + x := make([]uint64, 16) + r := rand.New(rand.NewPCG(0, 0)) + for i := range x { + x[i] = r.Uint64() + } + + bf.InsertBulk(x) + b.ResetTimer() + for i := 0; i < b.N; i++ { + bf.CheckBulk(x) + } + b.SetBytes(bytesPerFilterBlock * int64(len(x))) +} + +func BenchmarkFilterInsertBulk(b *testing.B) { + bf := blockSplitBloomFilter{bitset32: make([]uint32, 99*bitsSetPerBlock)} + x := make([]uint64, 16) + r := rand.New(rand.NewPCG(0, 0)) + for i := range x { + x[i] = r.Uint64() + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + bf.InsertBulk(x) + } + b.SetBytes(bytesPerFilterBlock * int64(len(x))) +} diff --git a/parquet/metadata/cleanup_bloom_filter.go b/parquet/metadata/cleanup_bloom_filter.go new file mode 100644 index 00000000..ed835c3b --- /dev/null +++ b/parquet/metadata/cleanup_bloom_filter.go @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build go1.24 + +package metadata + +import ( + "runtime" + "sync" + + "github.com/apache/arrow-go/v18/arrow/memory" +) + +func addCleanup(bf *blockSplitBloomFilter, bufferPool *sync.Pool) { + runtime.AddCleanup(bf, func(data *memory.Buffer) { + if bufferPool != nil { + data.ResizeNoShrink(0) + bufferPool.Put(data) + } else { + data.Release() + } + }, bf.data) +} diff --git a/parquet/metadata/cleanup_bloom_filter_go1.23.go b/parquet/metadata/cleanup_bloom_filter_go1.23.go new file mode 100644 index 00000000..b4bffbe7 --- /dev/null +++ b/parquet/metadata/cleanup_bloom_filter_go1.23.go @@ -0,0 +1,35 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !go1.24 + +package metadata + +import ( + "runtime" + "sync" +) + +func addCleanup(bf *blockSplitBloomFilter, bufferPool *sync.Pool) { + runtime.SetFinalizer(bf, func(f *blockSplitBloomFilter) { + if bufferPool != nil { + f.data.ResizeNoShrink(0) + bufferPool.Put(f.data) + } else { + f.data.Release() + } + }) +} diff --git a/parquet/metadata/column_chunk.go b/parquet/metadata/column_chunk.go index 22c9b9e3..848d20fe 100644 --- a/parquet/metadata/column_chunk.go +++ b/parquet/metadata/column_chunk.go @@ -218,6 +218,13 @@ func (c *ColumnChunkMetaData) BloomFilterOffset() int64 { return c.columnMeta.GetBloomFilterOffset() } +// BloomFilterLength is the length of the serialized bloomfilter including the +// serialized bloom filter header. This was only added in 2.10 so it may not exist, +// returning 0 in that case. +func (c *ColumnChunkMetaData) BloomFilterLength() int32 { + return c.columnMeta.GetBloomFilterLength() +} + // StatsSet returns true only if there are statistics set in the metadata and the column // descriptor has a sort order that is not SortUnknown // @@ -267,6 +274,7 @@ type ColumnChunkMetaDataBuilder struct { compressedSize int64 uncompressedSize int64 + fileOffset int64 } func NewColumnChunkMetaDataBuilder(props *parquet.WriterProperties, column *schema.Column) *ColumnChunkMetaDataBuilder { @@ -347,14 +355,15 @@ type EncodingStats struct { } // Finish finalizes the metadata with the given offsets, -// flushes any compression that needs to be done, and performs -// any encryption if an encryptor is provided. -func (c *ColumnChunkMetaDataBuilder) Finish(info ChunkMetaInfo, hasDict, dictFallback bool, encStats EncodingStats, metaEncryptor encryption.Encryptor) error { +// flushes any compression that needs to be done. +// Encryption will be performed by calling PopulateCryptoData +// after this function is called. 
+func (c *ColumnChunkMetaDataBuilder) Finish(info ChunkMetaInfo, hasDict, dictFallback bool, encStats EncodingStats) error { if info.DictPageOffset > 0 { c.chunk.MetaData.DictionaryPageOffset = &info.DictPageOffset - c.chunk.FileOffset = info.DictPageOffset + info.CompressedSize + c.fileOffset = info.DictPageOffset } else { - c.chunk.FileOffset = info.DataPageOffset + info.CompressedSize + c.fileOffset = info.DataPageOffset } c.chunk.MetaData.NumValues = info.NumValues @@ -411,6 +420,10 @@ func (c *ColumnChunkMetaDataBuilder) Finish(info ChunkMetaInfo, hasDict, dictFal } c.chunk.MetaData.EncodingStats = thriftEncodingStats + return nil +} + +func (c *ColumnChunkMetaDataBuilder) PopulateCryptoData(encryptor encryption.Encryptor) error { encryptProps := c.props.ColumnEncryptionProperties(c.column.Path()) if encryptProps != nil && encryptProps.IsEncrypted() { ccmd := format.NewColumnCryptoMetaData() @@ -436,7 +449,7 @@ func (c *ColumnChunkMetaDataBuilder) Finish(info ChunkMetaInfo, hasDict, dictFal return err } var buf bytes.Buffer - metaEncryptor.Encrypt(&buf, data) + encryptor.Encrypt(&buf, data) c.chunk.EncryptedColumnMetadata = buf.Bytes() if encryptedFooter { diff --git a/parquet/metadata/file.go b/parquet/metadata/file.go index 39a31929..95d3813b 100644 --- a/parquet/metadata/file.go +++ b/parquet/metadata/file.go @@ -47,6 +47,7 @@ type FileMetaDataBuilder struct { currentRgBldr *RowGroupMetaDataBuilder kvmeta KeyValueMetadata cryptoMetadata *format.FileCryptoMetaData + fileEncryptor encryption.FileEncryptor } // NewFileMetadataBuilder will use the default writer properties if nil is passed for @@ -65,6 +66,10 @@ func NewFileMetadataBuilder(schema *schema.Schema, props *parquet.WriterProperti } } +func (f *FileMetaDataBuilder) SetFileEncryptor(encryptor encryption.FileEncryptor) { + f.fileEncryptor = encryptor +} + // GetFileCryptoMetaData returns the cryptographic information for encrypting/ // decrypting the file. 
func (f *FileMetaDataBuilder) GetFileCryptoMetaData() *FileCryptoMetadata { @@ -92,6 +97,7 @@ func (f *FileMetaDataBuilder) AppendRowGroup() *RowGroupMetaDataBuilder { rg := format.NewRowGroup() f.rowGroups = append(f.rowGroups, rg) f.currentRgBldr = NewRowGroupMetaDataBuilder(f.props, f.schema, rg) + f.currentRgBldr.fileEncryptor = f.fileEncryptor return f.currentRgBldr } diff --git a/parquet/metadata/metadata_test.go b/parquet/metadata/metadata_test.go index 2f35d636..fccfbe4b 100644 --- a/parquet/metadata/metadata_test.go +++ b/parquet/metadata/metadata_test.go @@ -46,8 +46,8 @@ func generateTableMetaData(schema *schema.Schema, props *parquet.WriterPropertie statsFloat.Signed = true col2Builder.SetStats(statsFloat) - col1Builder.Finish(metadata.ChunkMetaInfo{int64(nrows) / 2, 4, 0, 10, 512, 600}, true, false, metadata.EncodingStats{dictEncodingStats, dataEncodingStats}, nil) - col2Builder.Finish(metadata.ChunkMetaInfo{int64(nrows) / 2, 24, 0, 30, 512, 600}, true, false, metadata.EncodingStats{dictEncodingStats, dataEncodingStats}, nil) + col1Builder.Finish(metadata.ChunkMetaInfo{int64(nrows) / 2, 4, 0, 10, 512, 600}, true, false, metadata.EncodingStats{dictEncodingStats, dataEncodingStats}) + col2Builder.Finish(metadata.ChunkMetaInfo{int64(nrows) / 2, 24, 0, 30, 512, 600}, true, false, metadata.EncodingStats{dictEncodingStats, dataEncodingStats}) rg1Builder.SetNumRows(nrows / 2) rg1Builder.Finish(1024, -1) @@ -60,8 +60,8 @@ func generateTableMetaData(schema *schema.Schema, props *parquet.WriterPropertie col1Builder.SetStats(statsInt) col2Builder.SetStats(statsFloat) dictEncodingStats = make(map[parquet.Encoding]int32) - col1Builder.Finish(metadata.ChunkMetaInfo{int64(nrows) / 2, 0 /*dictionary page offset*/, 0, 10, 512, 600}, false /* has dictionary */, false, metadata.EncodingStats{dictEncodingStats, dataEncodingStats}, nil) - col2Builder.Finish(metadata.ChunkMetaInfo{int64(nrows) / 2, 16, 0, 26, 512, 600}, true, false, metadata.EncodingStats{dictEncodingStats, dataEncodingStats}, nil) + col1Builder.Finish(metadata.ChunkMetaInfo{int64(nrows) / 2, 0 /*dictionary page offset*/, 0, 10, 512, 600}, false /* has dictionary */, false, metadata.EncodingStats{dictEncodingStats, dataEncodingStats}) + col2Builder.Finish(metadata.ChunkMetaInfo{int64(nrows) / 2, 16, 0, 26, 512, 600}, true, false, metadata.EncodingStats{dictEncodingStats, dataEncodingStats}) rg2Builder.SetNumRows(nrows / 2) rg2Builder.Finish(1024, -1) diff --git a/parquet/metadata/row_group.go b/parquet/metadata/row_group.go index e4ec4c76..5ccd2e62 100644 --- a/parquet/metadata/row_group.go +++ b/parquet/metadata/row_group.go @@ -99,6 +99,8 @@ type RowGroupMetaDataBuilder struct { schema *schema.Schema colBuilders []*ColumnChunkMetaDataBuilder nextCol int + + fileEncryptor encryption.FileEncryptor } // NewRowGroupMetaDataBuilder returns a builder using the given properties and underlying thrift object. 
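Taken together with the column_chunk.go hunk above, these changes split metadata encryption out of ColumnChunkMetaDataBuilder.Finish: the file encryptor is registered once on the FileMetaDataBuilder and applied per column when the row group is finalized. A minimal sketch of the resulting call pattern, assuming the schema, writer properties, and file encryptor come from the caller (as in the encrypted test in this patch), a single leaf column, and empty chunk info purely as a placeholder:

func buildEncryptedMetadata(sc *schema.Schema, props *parquet.WriterProperties,
	enc encryption.FileEncryptor) (*metadata.FileMetaData, error) {
	bldr := metadata.NewFileMetadataBuilder(sc, props, nil)
	bldr.SetFileEncryptor(enc) // nil leaves column metadata unencrypted

	rg := bldr.AppendRowGroup() // the row group builder inherits the file encryptor
	col := rg.NextColumnChunk() // one NextColumnChunk call per leaf column (one here)
	// Finish no longer receives an encryptor.
	if err := col.Finish(metadata.ChunkMetaInfo{}, false, false, metadata.EncodingStats{}); err != nil {
		return nil, err
	}
	// Finalizing the row group invokes PopulateCryptoData on each column builder
	// with that column's meta encryptor, replacing the old in-Finish encryption.
	if err := rg.Finish(0, 0); err != nil {
		return nil, err
	}
	return bldr.Finish()
}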
@@ -166,21 +168,24 @@ func (r *RowGroupMetaDataBuilder) Finish(_ int64, ordinal int16) error { totalUncompressed int64 ) - for idx, col := range r.rg.Columns { - if col.FileOffset < 0 { - return fmt.Errorf("parquet: Column %d is not complete", idx) - } + for idx := range r.rg.Columns { if idx == 0 { - if col.MetaData.IsSetDictionaryPageOffset() && col.MetaData.GetDictionaryPageOffset() > 0 { - fileOffset = col.MetaData.GetDictionaryPageOffset() - } else { - fileOffset = col.MetaData.DataPageOffset - } + fileOffset = r.colBuilders[idx].fileOffset } + // sometimes column metadata is encrypted and not available to read // so we must get total compressed size from column builder totalCompressed += r.colBuilders[idx].TotalCompressedSize() totalUncompressed += r.colBuilders[idx].TotalUncompressedSize() + + if r.fileEncryptor != nil { + enc := r.fileEncryptor.GetColumnMetaEncryptor(r.colBuilders[idx].Descr().Path()) + if enc != nil { + enc.UpdateAad(encryption.CreateModuleAad(enc.FileAad(), encryption.ColumnMetaModule, + ordinal, int16(idx), -1)) + r.colBuilders[idx].PopulateCryptoData(enc) + } + } } if len(r.props.SortingColumns()) > 0 { diff --git a/parquet/pqarrow/column_readers.go b/parquet/pqarrow/column_readers.go index 921774d4..ab08a7b8 100644 --- a/parquet/pqarrow/column_readers.go +++ b/parquet/pqarrow/column_readers.go @@ -49,7 +49,7 @@ type leafReader struct { recordRdr file.RecordReader props ArrowReadProperties - refCount int64 + refCount atomic.Int64 } func newLeafReader(rctx *readerCtx, field *arrow.Field, input *columnIterator, leafInfo file.LevelInfo, props ArrowReadProperties, bufferPool *sync.Pool) (*ColumnReader, error) { @@ -60,18 +60,19 @@ func newLeafReader(rctx *readerCtx, field *arrow.Field, input *columnIterator, l descr: input.Descr(), recordRdr: file.NewRecordReader(input.Descr(), leafInfo, field.Type, rctx.mem, bufferPool), props: props, - refCount: 1, } + ret.refCount.Add(1) + err := ret.nextRowGroup() return &ColumnReader{ret}, err } func (lr *leafReader) Retain() { - atomic.AddInt64(&lr.refCount, 1) + lr.refCount.Add(1) } func (lr *leafReader) Release() { - if atomic.AddInt64(&lr.refCount, -1) == 0 { + if lr.refCount.Add(-1) == 0 { lr.releaseOut() if lr.recordRdr != nil { lr.recordRdr.Release() @@ -92,7 +93,7 @@ func (lr *leafReader) IsOrHasRepeatedChild() bool { return false } func (lr *leafReader) LoadBatch(nrecords int64) (err error) { lr.releaseOut() - lr.recordRdr.Reset() + lr.recordRdr.ResetValues() if err := lr.recordRdr.Reserve(nrecords); err != nil { return err @@ -135,6 +136,16 @@ func (lr *leafReader) clearOut() (out *arrow.Chunked) { func (lr *leafReader) Field() *arrow.Field { return lr.field } +func (lr *leafReader) SeekToRow(rowIdx int64) error { + pr, offset, err := lr.input.FindChunkForRow(rowIdx) + if err != nil { + return err + } + + lr.recordRdr.SetPageReader(pr) + return lr.recordRdr.SeekToRow(offset) +} + func (lr *leafReader) nextRowGroup() error { pr, err := lr.input.NextChunk() if err != nil { @@ -155,15 +166,15 @@ type structReader struct { hasRepeatedChild bool props ArrowReadProperties - refCount int64 + refCount atomic.Int64 } func (sr *structReader) Retain() { - atomic.AddInt64(&sr.refCount, 1) + sr.refCount.Add(1) } func (sr *structReader) Release() { - if atomic.AddInt64(&sr.refCount, -1) == 0 { + if sr.refCount.Add(-1) == 0 { if sr.defRepLevelChild != nil { sr.defRepLevelChild.Release() sr.defRepLevelChild = nil @@ -182,8 +193,8 @@ func newStructReader(rctx *readerCtx, filtered *arrow.Field, levelInfo file.Leve levelInfo: 
levelInfo, children: children, props: props, - refCount: 1, } + ret.refCount.Add(1) // there could be a mix of children some might be repeated and some might not be // if possible use one that isn't since that will be guaranteed to have the least @@ -227,6 +238,21 @@ func (sr *structReader) GetRepLevels() ([]int16, error) { return sr.defRepLevelChild.GetRepLevels() } +func (sr *structReader) SeekToRow(rowIdx int64) error { + var g errgroup.Group + if !sr.props.Parallel { + g.SetLimit(1) + } + + for _, rdr := range sr.children { + g.Go(func() error { + return rdr.SeekToRow(rowIdx) + }) + } + + return g.Wait() +} + func (sr *structReader) LoadBatch(nrecords int64) error { // Load batches in parallel // When reading structs with large numbers of columns, the serial load is very slow. @@ -323,20 +349,24 @@ type listReader struct { info file.LevelInfo itemRdr *ColumnReader props ArrowReadProperties - refCount int64 + refCount atomic.Int64 } func newListReader(rctx *readerCtx, field *arrow.Field, info file.LevelInfo, childRdr *ColumnReader, props ArrowReadProperties) *ColumnReader { childRdr.Retain() - return &ColumnReader{&listReader{rctx, field, info, childRdr, props, 1}} + lr := &listReader{rctx: rctx, field: field, info: info, itemRdr: childRdr, props: props} + lr.refCount.Add(1) + return &ColumnReader{ + lr, + } } func (lr *listReader) Retain() { - atomic.AddInt64(&lr.refCount, 1) + lr.refCount.Add(1) } func (lr *listReader) Release() { - if atomic.AddInt64(&lr.refCount, -1) == 0 { + if lr.refCount.Add(-1) == 0 { if lr.itemRdr != nil { lr.itemRdr.Release() lr.itemRdr = nil @@ -356,6 +386,10 @@ func (lr *listReader) Field() *arrow.Field { return lr.field } func (lr *listReader) IsOrHasRepeatedChild() bool { return true } +func (lr *listReader) SeekToRow(rowIdx int64) error { + return lr.itemRdr.SeekToRow(rowIdx) +} + func (lr *listReader) LoadBatch(nrecords int64) error { return lr.itemRdr.LoadBatch(nrecords) } @@ -439,12 +473,19 @@ func (lr *listReader) BuildArray(lenBound int64) (*arrow.Chunked, error) { // column reader logic for fixed size lists instead of variable length ones. type fixedSizeListReader struct { - listReader + *listReader } func newFixedSizeListReader(rctx *readerCtx, field *arrow.Field, info file.LevelInfo, childRdr *ColumnReader, props ArrowReadProperties) *ColumnReader { childRdr.Retain() - return &ColumnReader{&fixedSizeListReader{listReader{rctx, field, info, childRdr, props, 1}}} + lr := listReader{rctx: rctx, field: field, info: info, itemRdr: childRdr, props: props} + lr.refCount.Add(1) + + return &ColumnReader{ + &fixedSizeListReader{ + &lr, + }, + } } // helper function to combine chunks into a single array. 
@@ -584,9 +625,7 @@ func transferBinary(rdr file.RecordReader, dt arrow.DataType) *arrow.Chunked { } func transferInt(rdr file.RecordReader, dt arrow.DataType) arrow.ArrayData { - var ( - output reflect.Value - ) + var output reflect.Value signed := true // create buffer for proper type since parquet only has int32 and int64 @@ -766,9 +805,7 @@ func transferDecimalInteger(rdr file.RecordReader, dt arrow.DataType) arrow.Arra } func uint64FromBigEndianShifted(buf []byte) uint64 { - var ( - bytes [8]byte - ) + var bytes [8]byte copy(bytes[8-len(buf):], buf) return binary.BigEndian.Uint64(bytes[:]) } diff --git a/parquet/pqarrow/file_reader.go b/parquet/pqarrow/file_reader.go index d6eae17a..df736135 100644 --- a/parquet/pqarrow/file_reader.go +++ b/parquet/pqarrow/file_reader.go @@ -21,6 +21,7 @@ import ( "errors" "fmt" "io" + "slices" "sync" "sync/atomic" @@ -116,6 +117,7 @@ type colReaderImpl interface { GetDefLevels() ([]int16, error) GetRepLevels() ([]int16, error) Field() *arrow.Field + SeekToRow(int64) error IsOrHasRepeatedChild() bool Retain() Release() @@ -387,7 +389,6 @@ func (fr *FileReader) ReadRowGroups(ctx context.Context, indices, rowGroups []in // if the context is in error, but we haven't set an error yet, then it means that the parent context // was cancelled. In this case, we should exit early as some columns may not have been read yet. err = errors.Join(err, ctx.Err()) - if err != nil { // if we encountered an error, consume any waiting data on the channel // so the goroutines don't leak and so memory can get cleaned up. we already @@ -427,6 +428,20 @@ func (fr *FileReader) getColumnReader(ctx context.Context, i int, colFactory itr type RecordReader interface { array.RecordReader arrio.Reader + // SeekToRow will shift the record reader so that subsequent calls to Read + // or Next will begin from the specified row. + // + // If the record reader was constructed with a request for a subset of row + // groups, then rows are counted across the requested row groups, not the + // entire file. This prevents reading row groups that were requested to be + // skipped, and allows treating the subset of row groups as a single collection + // of rows. + // + // If the file contains offset indexes for a given column, they will be + // utilized to skip pages as needed to find the requested row. Otherwise, + // page headers will still have to be read to find the right page to begin + // reading from. + SeekToRow(int64) error } // GetRecordReader returns a record reader that reads only the requested column indexes and row groups. @@ -475,14 +490,15 @@ func (fr *FileReader) GetRecordReader(ctx context.Context, colIndices, rowGroups if fr.Props.BatchSize <= 0 { batchSize = nrows } - return &recordReader{ + rr := &recordReader{ numRows: nrows, batchSize: batchSize, parallel: fr.Props.Parallel, sc: sc, fieldReaders: readers, - refCount: 1, - }, nil + } + rr.refCount.Add(1) + return rr, nil } func (fr *FileReader) getReader(ctx context.Context, field *SchemaField, arrowField arrow.Field) (out *ColumnReader, err error) { @@ -537,17 +553,15 @@ func (fr *FileReader) getReader(ctx context.Context, field *SchemaField, arrowFi } // because we performed getReader concurrently, we need to prune out any empty readers - for n := len(childReaders) - 1; n >= 0; n-- { - if childReaders[n] == nil { - childReaders = append(childReaders[:n], childReaders[n+1:]...) - childFields = append(childFields[:n], childFields[n+1:]...)
- } - } + childReaders = slices.DeleteFunc(childReaders, + func(r *ColumnReader) bool { return r == nil }) if len(childFields) == 0 { return nil, nil } - filtered := arrow.Field{Name: arrowField.Name, Nullable: arrowField.Nullable, - Metadata: arrowField.Metadata, Type: arrow.StructOf(childFields...)} + filtered := arrow.Field{ + Name: arrowField.Name, Nullable: arrowField.Nullable, + Metadata: arrowField.Metadata, Type: arrow.StructOf(childFields...), + } out = newStructReader(&rctx, &filtered, field.LevelInfo, childReaders, fr.Props) case arrow.LIST, arrow.FIXED_SIZE_LIST, arrow.MAP: child := field.Children[0] @@ -615,15 +629,45 @@ type columnIterator struct { rdr *file.Reader schema *schema.Schema rowGroups []int + + rgIdx int } -func (c *columnIterator) NextChunk() (file.PageReader, error) { +func (c *columnIterator) FindChunkForRow(rowIdx int64) (file.PageReader, int64, error) { if len(c.rowGroups) == 0 { + return nil, 0, nil + } + + if rowIdx < 0 || rowIdx > c.rdr.NumRows() { + return nil, 0, fmt.Errorf("invalid row index %d, file only has %d rows", rowIdx, c.rdr.NumRows()) + } + + idx := int64(0) + for i, rg := range c.rowGroups { + rgr := c.rdr.RowGroup(rg) + if idx+rgr.NumRows() > rowIdx { + c.rgIdx = i + 1 + pr, err := rgr.GetColumnPageReader(c.index) + if err != nil { + return nil, 0, err + } + + return pr, rowIdx - idx, nil + } + idx += rgr.NumRows() + } + + return nil, 0, fmt.Errorf("%w: invalid row index %d, row group subset only has %d total rows", + arrow.ErrInvalid, rowIdx, idx) +} + +func (c *columnIterator) NextChunk() (file.PageReader, error) { + if len(c.rowGroups) == 0 || c.rgIdx >= len(c.rowGroups) { return nil, nil } - rgr := c.rdr.RowGroup(c.rowGroups[0]) - c.rowGroups = c.rowGroups[1:] + rgr := c.rdr.RowGroup(c.rowGroups[c.rgIdx]) + c.rgIdx++ return rgr.GetColumnPageReader(c.index) } @@ -640,15 +684,34 @@ type recordReader struct { cur arrow.Record err error - refCount int64 + refCount atomic.Int64 +} + +func (r *recordReader) SeekToRow(row int64) error { + if r.cur != nil { + r.cur.Release() + r.cur = nil + } + + if row < 0 || row >= r.numRows { + return fmt.Errorf("invalid row index %d, file only has %d rows", row, r.numRows) + } + + for _, fr := range r.fieldReaders { + if err := fr.SeekToRow(row); err != nil { + return err + } + } + + return nil } func (r *recordReader) Retain() { - atomic.AddInt64(&r.refCount, 1) + r.refCount.Add(1) } func (r *recordReader) Release() { - if atomic.AddInt64(&r.refCount, -1) == 0 { + if r.refCount.Add(-1) == 0 { if r.cur != nil { r.cur.Release() r.cur = nil diff --git a/parquet/pqarrow/file_reader_test.go b/parquet/pqarrow/file_reader_test.go index 9010927e..bca51647 100644 --- a/parquet/pqarrow/file_reader_test.go +++ b/parquet/pqarrow/file_reader_test.go @@ -285,6 +285,173 @@ func TestRecordReaderSerial(t *testing.T) { assert.Nil(t, rec) } +func TestRecordReaderSeekToRow(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer mem.AssertSize(t, 0) + + tbl := makeDateTimeTypesTable(mem, true, true) + defer tbl.Release() + + var buf bytes.Buffer + require.NoError(t, pqarrow.WriteTable(tbl, &buf, tbl.NumRows(), nil, pqarrow.NewArrowWriterProperties(pqarrow.WithAllocator(mem)))) + + pf, err := file.NewParquetReader(bytes.NewReader(buf.Bytes()), file.WithReadProps(parquet.NewReaderProperties(mem))) + require.NoError(t, err) + + reader, err := pqarrow.NewFileReader(pf, pqarrow.ArrowReadProperties{BatchSize: 2}, mem) + require.NoError(t, err) + + sc, err := reader.Schema() + assert.NoError(t, err) + 
assert.Truef(t, tbl.Schema().Equal(sc), "expected: %s\ngot: %s", tbl.Schema(), sc) + + rr, err := reader.GetRecordReader(context.Background(), nil, nil) + assert.NoError(t, err) + assert.NotNil(t, rr) + defer rr.Release() + + tr := array.NewTableReader(tbl, 2) + defer tr.Release() + + rec, err := rr.Read() + assert.NoError(t, err) + tr.Next() + assert.Truef(t, array.RecordEqual(tr.Record(), rec), "expected: %s\ngot: %s", tr.Record(), rec) + + require.NoError(t, rr.SeekToRow(0)) + rec, err = rr.Read() + assert.NoError(t, err) + assert.Truef(t, array.RecordEqual(tr.Record(), rec), "expected: %s\ngot: %s", tr.Record(), rec) + + rec, err = rr.Read() + assert.NoError(t, err) + tr.Next() + assert.Truef(t, array.RecordEqual(tr.Record(), rec), "expected: %s\ngot: %s", tr.Record(), rec) + + require.NoError(t, rr.SeekToRow(2)) + rec, err = rr.Read() + assert.NoError(t, err) + assert.Truef(t, array.RecordEqual(tr.Record(), rec), "expected: %s\ngot: %s", tr.Record(), rec) + + require.NoError(t, rr.SeekToRow(4)) + rec, err = rr.Read() + tr.Next() + assert.NoError(t, err) + assert.Truef(t, array.RecordEqual(tr.Record(), rec), "expected: %s\ngot: %s", tr.Record(), rec) +} + +func TestRecordReaderMultiRowGroup(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer mem.AssertSize(t, 0) + + tbl := makeDateTimeTypesTable(mem, true, true) + defer tbl.Release() + + var buf bytes.Buffer + require.NoError(t, pqarrow.WriteTable(tbl, &buf, 2, nil, pqarrow.NewArrowWriterProperties(pqarrow.WithAllocator(mem)))) + + pf, err := file.NewParquetReader(bytes.NewReader(buf.Bytes()), file.WithReadProps(parquet.NewReaderProperties(mem))) + require.NoError(t, err) + + reader, err := pqarrow.NewFileReader(pf, pqarrow.ArrowReadProperties{BatchSize: 2}, mem) + require.NoError(t, err) + + sc, err := reader.Schema() + assert.NoError(t, err) + assert.Truef(t, tbl.Schema().Equal(sc), "expected: %s\ngot: %s", tbl.Schema(), sc) + + rr, err := reader.GetRecordReader(context.Background(), nil, nil) + assert.NoError(t, err) + assert.NotNil(t, rr) + defer rr.Release() + + tr := array.NewTableReader(tbl, 2) + defer tr.Release() + + rec, err := rr.Read() + assert.NoError(t, err) + tr.Next() + assert.Truef(t, array.RecordEqual(tr.Record(), rec), "expected: %s\ngot: %s", tr.Record(), rec) + + rec, err = rr.Read() + assert.NoError(t, err) + tr.Next() + assert.Truef(t, array.RecordEqual(tr.Record(), rec), "expected: %s\ngot: %s", tr.Record(), rec) + + rec, err = rr.Read() + assert.NoError(t, err) + tr.Next() + assert.Truef(t, array.RecordEqual(tr.Record(), rec), "expected: %s\ngot: %s", tr.Record(), rec) + + rec, err = rr.Read() + assert.Same(t, io.EOF, err) + assert.Nil(t, rec) +} + +func TestRecordReaderSeekToRowMultiRowGroup(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer mem.AssertSize(t, 0) + + tbl := makeDateTimeTypesTable(mem, true, true) + defer tbl.Release() + + var buf bytes.Buffer + require.NoError(t, pqarrow.WriteTable(tbl, &buf, 2, nil, pqarrow.NewArrowWriterProperties(pqarrow.WithAllocator(mem)))) + + pf, err := file.NewParquetReader(bytes.NewReader(buf.Bytes()), file.WithReadProps(parquet.NewReaderProperties(mem))) + require.NoError(t, err) + + reader, err := pqarrow.NewFileReader(pf, pqarrow.ArrowReadProperties{BatchSize: 2}, mem) + require.NoError(t, err) + + sc, err := reader.Schema() + assert.NoError(t, err) + assert.Truef(t, tbl.Schema().Equal(sc), "expected: %s\ngot: %s", tbl.Schema(), sc) + + rr, err := reader.GetRecordReader(context.Background(), nil, 
nil) + assert.NoError(t, err) + assert.NotNil(t, rr) + defer rr.Release() + + tr := array.NewTableReader(tbl, 2) + defer tr.Release() + + rec, err := rr.Read() + assert.NoError(t, err) + tr.Next() + first := tr.Record() + first.Retain() + defer first.Release() + + assert.Truef(t, array.RecordEqual(tr.Record(), rec), "expected: %s\ngot: %s", tr.Record(), rec) + + require.NoError(t, rr.SeekToRow(0)) + rec, err = rr.Read() + assert.NoError(t, err) + assert.Truef(t, array.RecordEqual(tr.Record(), rec), "expected: %s\ngot: %s", tr.Record(), rec) + + rec, err = rr.Read() + assert.NoError(t, err) + tr.Next() + assert.Truef(t, array.RecordEqual(tr.Record(), rec), "expected: %s\ngot: %s", tr.Record(), rec) + + require.NoError(t, rr.SeekToRow(2)) + rec, err = rr.Read() + assert.NoError(t, err) + assert.Truef(t, array.RecordEqual(tr.Record(), rec), "expected: %s\ngot: %s", tr.Record(), rec) + + require.NoError(t, rr.SeekToRow(4)) + rec, err = rr.Read() + tr.Next() + assert.NoError(t, err) + assert.Truef(t, array.RecordEqual(tr.Record(), rec), "expected: %s\ngot: %s", tr.Record(), rec) + + require.NoError(t, rr.SeekToRow(0)) + rec, err = rr.Read() + assert.NoError(t, err) + assert.Truef(t, array.RecordEqual(first, rec), "expected: %s\ngot: %s", first, rec) +} + func TestFileReaderWriterMetadata(t *testing.T) { mem := memory.NewCheckedAllocator(memory.DefaultAllocator) defer mem.AssertSize(t, 0) diff --git a/parquet/pqarrow/file_writer.go b/parquet/pqarrow/file_writer.go index 45cfe49d..e4e99368 100644 --- a/parquet/pqarrow/file_writer.go +++ b/parquet/pqarrow/file_writer.go @@ -338,3 +338,10 @@ func (fw *FileWriter) WriteColumnData(data arrow.Array) error { defer chunked.Release() return fw.WriteColumnChunked(chunked, 0, int64(data.Len())) } + +// FileMetadata returns the current state of the FileMetadata that would be written +// if this file were to be closed. If the file has already been closed, then this +// will return the FileMetaData which was written to the file. 
+func (fw *FileWriter) FileMetadata() (*metadata.FileMetaData, error) { + return fw.wr.FileMetadata() +} diff --git a/parquet/pqarrow/path_builder.go b/parquet/pqarrow/path_builder.go index 92b12560..7abd0f5c 100644 --- a/parquet/pqarrow/path_builder.go +++ b/parquet/pqarrow/path_builder.go @@ -395,7 +395,7 @@ func (p *pathBuilder) Visit(arr arrow.Array) error { repLevel: p.info.maxRepLevel, defLevelIfEmpty: p.info.maxDefLevel - 1, }) - p.nullableInParent = ok + p.nullableInParent = arr.DataType().(arrow.ListLikeType).ElemField().Nullable return p.Visit(larr.ListValues()) case arrow.FIXED_SIZE_LIST: p.maybeAddNullable(arr) diff --git a/parquet/pqarrow/path_builder_test.go b/parquet/pqarrow/path_builder_test.go index bb9d8bf4..df548bcd 100644 --- a/parquet/pqarrow/path_builder_test.go +++ b/parquet/pqarrow/path_builder_test.go @@ -39,7 +39,8 @@ func TestNonNullableSingleList(t *testing.T) { // So: // def level 0: a null entry // def level 1: a non-null entry - bldr := array.NewListBuilder(memory.DefaultAllocator, arrow.PrimitiveTypes.Int64) + bldr := array.NewBuilder(memory.DefaultAllocator, + arrow.ListOfNonNullable(arrow.PrimitiveTypes.Int64)).(*array.ListBuilder) defer bldr.Release() vb := bldr.ValueBuilder().(*array.Int64Builder) @@ -67,7 +68,7 @@ func TestNonNullableSingleList(t *testing.T) { result, err := mp.write(0, ctx) require.NoError(t, err) - assert.Equal(t, []int16{2, 2, 2, 2, 2, 2}, result.defLevels) + assert.Equal(t, []int16{1, 1, 1, 1, 1, 1}, result.defLevels) assert.Equal(t, []int16{0, 0, 1, 0, 1, 1}, result.repLevels) assert.Len(t, result.postListVisitedElems, 1) assert.EqualValues(t, 0, result.postListVisitedElems[0].start) diff --git a/parquet/pqarrow/schema.go b/parquet/pqarrow/schema.go index 416d59f1..7dacd789 100644 --- a/parquet/pqarrow/schema.go +++ b/parquet/pqarrow/schema.go @@ -18,6 +18,7 @@ package pqarrow import ( "encoding/base64" + "errors" "fmt" "math" "strconv" @@ -240,9 +241,26 @@ func repFromNullable(isnullable bool) parquet.Repetition { return parquet.Repetitions.Required } -func structToNode(typ *arrow.StructType, name string, nullable bool, props *parquet.WriterProperties, arrprops ArrowWriterProperties) (schema.Node, error) { +func variantToNode(t *variantExtensionType, field arrow.Field, props *parquet.WriterProperties, arrProps ArrowWriterProperties) (schema.Node, error) { + metadataNode, err := fieldToNode("metadata", t.Metadata(), props, arrProps) + if err != nil { + return nil, err + } + + valueNode, err := fieldToNode("value", t.Value(), props, arrProps) + if err != nil { + return nil, err + } + + return schema.NewGroupNodeLogical(field.Name, repFromNullable(field.Nullable), + schema.FieldList{metadataNode, valueNode}, schema.VariantLogicalType{}, + fieldIDFromMeta(field.Metadata)) +} + +func structToNode(field arrow.Field, props *parquet.WriterProperties, arrprops ArrowWriterProperties) (schema.Node, error) { + typ := field.Type.(*arrow.StructType) if typ.NumFields() == 0 { - return nil, fmt.Errorf("cannot write struct type '%s' with no children field to parquet. Consider adding a dummy child", name) + return nil, fmt.Errorf("cannot write struct type '%s' with no children field to parquet. 
Consider adding a dummy child", field.Name) } children := make(schema.FieldList, 0, typ.NumFields()) @@ -254,7 +272,7 @@ func structToNode(typ *arrow.StructType, name string, nullable bool, props *parq children = append(children, n) } - return schema.NewGroupNode(name, repFromNullable(nullable), children, -1) + return schema.NewGroupNode(field.Name, repFromNullable(field.Nullable), children, fieldIDFromMeta(field.Metadata)) } func fieldToNode(name string, field arrow.Field, props *parquet.WriterProperties, arrprops ArrowWriterProperties) (schema.Node, error) { @@ -267,7 +285,7 @@ func fieldToNode(name string, field arrow.Field, props *parquet.WriterProperties return nil, xerrors.New("nulltype arrow field must be nullable") } case arrow.STRUCT: - return structToNode(field.Type.(*arrow.StructType), field.Name, field.Nullable, props, arrprops) + return structToNode(field, props, arrprops) case arrow.FIXED_SIZE_LIST, arrow.LIST: elemField := field.Type.(arrow.ListLikeType).ElemField() @@ -276,7 +294,7 @@ func fieldToNode(name string, field arrow.Field, props *parquet.WriterProperties return nil, err } - return schema.ListOfWithName(name, child, repFromNullable(field.Nullable), -1) + return schema.ListOfWithName(name, child, repFromNullable(field.Nullable), fieldIDFromMeta(field.Metadata)) case arrow.DICTIONARY: // parquet has no dictionary type, dictionary is encoding, not schema level dictType := field.Type.(*arrow.DictionaryType) @@ -302,9 +320,14 @@ func fieldToNode(name string, field arrow.Field, props *parquet.WriterProperties } return schema.NewGroupNode(field.Name, repFromNullable(field.Nullable), schema.FieldList{ keyvalNode, - }, -1) + }, fieldIDFromMeta(field.Metadata)) + } + return schema.MapOf(field.Name, keyNode, valueNode, repFromNullable(field.Nullable), fieldIDFromMeta(field.Metadata)) + case arrow.EXTENSION: + extType := field.Type.(arrow.ExtensionType) + if extType.ExtensionName() == "parquet.variant" { + return variantToNode(extType.(*variantExtensionType), field, props, arrprops) } - return schema.MapOf(field.Name, keyNode, valueNode, repFromNullable(field.Nullable), -1) } // Not a GroupNode @@ -830,11 +853,29 @@ func mapToSchemaField(n *schema.GroupNode, currentLevels file.LevelInfo, ctx *sc return nil } +func variantToSchemaField(n *schema.GroupNode, currentLevels file.LevelInfo, ctx *schemaTree, parent, out *SchemaField) error { + // this is for unshredded variants. 
shredded variants may have more fields + if n.NumFields() != 2 { + return errors.New("VARIANT group must have exactly 2 children") + } + + var err error + if err = groupToStructField(n, currentLevels, ctx, out); err != nil { + return err + } + + storageType := out.Field.Type + out.Field.Type, err = newVariantType(storageType) + return err +} + func groupToSchemaField(n *schema.GroupNode, currentLevels file.LevelInfo, ctx *schemaTree, parent, out *SchemaField) error { if n.LogicalType().Equals(schema.NewListLogicalType()) { return listToSchemaField(n, currentLevels, ctx, parent, out) } else if n.LogicalType().Equals(schema.MapLogicalType{}) { return mapToSchemaField(n, currentLevels, ctx, parent, out) + } else if n.LogicalType().Equals(schema.VariantLogicalType{}) { + return variantToSchemaField(n, currentLevels, ctx, parent, out) } if n.RepetitionType() == parquet.Repetitions.Repeated { diff --git a/parquet/pqarrow/schema_test.go b/parquet/pqarrow/schema_test.go index f075b466..ef03ae45 100644 --- a/parquet/pqarrow/schema_test.go +++ b/parquet/pqarrow/schema_test.go @@ -292,7 +292,7 @@ func TestConvertArrowFloat16(t *testing.T) { } } -func TestCoerceTImestampV1(t *testing.T) { +func TestCoerceTimestampV1(t *testing.T) { parquetFields := make(schema.FieldList, 0) arrowFields := make([]arrow.Field, 0) @@ -311,7 +311,7 @@ func TestCoerceTImestampV1(t *testing.T) { } } -func TestAutoCoerceTImestampV1(t *testing.T) { +func TestAutoCoerceTimestampV1(t *testing.T) { parquetFields := make(schema.FieldList, 0) arrowFields := make([]arrow.Field, 0) @@ -402,7 +402,7 @@ func TestListStructBackwardCompatible(t *testing.T) { schema.StringLogicalType{}, parquet.Types.ByteArray, -1, 3)), schema.MustPrimitive(schema.NewPrimitiveNodeLogical("class", parquet.Repetitions.Optional, schema.StringLogicalType{}, parquet.Types.ByteArray, -1, 4)), - }, -1)), + }, 5)), }, schema.NewListLogicalType(), 1)), }, -1))) @@ -417,7 +417,7 @@ func TestListStructBackwardCompatible(t *testing.T) { Metadata: arrow.NewMetadata([]string{"PARQUET:field_id"}, []string{"3"})}, arrow.Field{Name: "class", Type: arrow.BinaryTypes.String, Nullable: true, Metadata: arrow.NewMetadata([]string{"PARQUET:field_id"}, []string{"4"})}, - ), Nullable: true, Metadata: arrow.NewMetadata([]string{"PARQUET:field_id"}, []string{"-1"})}), + ), Nullable: true, Metadata: arrow.NewMetadata([]string{"PARQUET:field_id"}, []string{"5"})}), Nullable: true, Metadata: arrow.NewMetadata([]string{"PARQUET:field_id"}, []string{"1"})}, }, nil) @@ -472,3 +472,35 @@ func TestProperListElementNullability(t *testing.T) { require.NoError(t, err) assert.True(t, arrSchema.Equal(outSchema), "expected: %s, got: %s", arrSchema, outSchema) } + +func TestConvertSchemaParquetVariant(t *testing.T) { + // unshredded variant: + // optional group variant_col { + // required binary metadata; + // required binary value; + // } + // + // shredded variants will be added later + parquetFields := make(schema.FieldList, 0) + metadata := schema.NewByteArrayNode("metadata", parquet.Repetitions.Required, -1) + value := schema.NewByteArrayNode("value", parquet.Repetitions.Required, -1) + + variant, err := schema.NewGroupNodeLogical("variant_unshredded", parquet.Repetitions.Optional, + schema.FieldList{metadata, value}, schema.VariantLogicalType{}, -1) + require.NoError(t, err) + parquetFields = append(parquetFields, variant) + + pqschema := schema.NewSchema(schema.MustGroup(schema.NewGroupNode("schema", parquet.Repetitions.Required, parquetFields, -1))) + outSchema, err := 
pqarrow.FromParquet(pqschema, nil, nil) + require.NoError(t, err) + + assert.EqualValues(t, 1, outSchema.NumFields()) + assert.Equal(t, "variant_unshredded", outSchema.Field(0).Name) + assert.Equal(t, arrow.EXTENSION, outSchema.Field(0).Type.ID()) + + assert.Equal(t, "parquet.variant", outSchema.Field(0).Type.(arrow.ExtensionType).ExtensionName()) + + sc, err := pqarrow.ToParquet(outSchema, nil, pqarrow.DefaultWriterProps()) + require.NoError(t, err) + assert.True(t, pqschema.Equals(sc), pqschema.String(), sc.String()) +} diff --git a/parquet/pqarrow/variant.go b/parquet/pqarrow/variant.go new file mode 100644 index 00000000..4f836c06 --- /dev/null +++ b/parquet/pqarrow/variant.go @@ -0,0 +1,150 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pqarrow + +import ( + "fmt" + "reflect" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/array" + "github.com/apache/arrow-go/v18/parquet/schema" +) + +// variantArray is an experimental extension type, but is not yet fully supported. +type variantArray struct { + array.ExtensionArrayBase +} + +// variantExtensionType is experimental extension type that supports +// semi-structured objects that can be composed of primitives, arrays, and +// objects which can be queried by path. +// +// Unshredded variant representation: +// +// optional group variant_name (VARIANT) { +// required binary metadata; +// required binary value; +// } +// +// To read more about variant encoding, see the variant encoding spec at +// https://github.com/apache/parquet-format/blob/master/VariantEncoding.md +// +// To read more about variant shredding, see the variant shredding spec at +// https://github.com/apache/parquet-format/blob/master/VariantShredding.md +type variantExtensionType struct { + arrow.ExtensionBase + + // TODO: add shredded_value + metadata arrow.Field + value arrow.Field +} + +func (*variantExtensionType) ParquetLogicalType() schema.LogicalType { + return schema.VariantLogicalType{} +} + +func isBinaryField(f arrow.Field) bool { + return f.Type.ID() == arrow.BINARY || f.Type.ID() == arrow.LARGE_BINARY +} + +func isSupportedVariantStorage(dt arrow.DataType) bool { + // for now we only support unshredded variants. unshredded vairant storage + // type should be a struct with a binary metadata and binary value. + // + // In shredded variants, the binary value field can be replaced + // with one or more of the following: object, array, typed_value, and variant_value. + s, ok := dt.(*arrow.StructType) + if !ok { + return false + } + + if s.NumFields() != 2 { + return false + } + + // ordering of metadata and value fields does not matter, as we will + // assign these to the variant extension type's members. + // here we just need to check that both are present. 
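For illustration, the storage layout these checks accept (mirroring the variant_test.go additions further below; this sketch is not part of the patch) is simply a struct of two required binary fields:

package main

import (
	"fmt"

	"github.com/apache/arrow-go/v18/arrow"
)

func main() {
	// Two required (non-nullable) binary children named "metadata" and "value";
	// the order of the two fields does not matter for the check above.
	storage := arrow.StructOf(
		arrow.Field{Name: "metadata", Type: arrow.BinaryTypes.Binary},
		arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary},
	)
	fmt.Println(storage) // struct<metadata: binary, value: binary>
}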
+ metadataField, ok := s.FieldByName("metadata") + if !ok { + return false + } + + valueField, ok := s.FieldByName("value") + if !ok { + return false + } + + // both must be non-nullable binary types for unshredded variants for now + return isBinaryField(metadataField) && isBinaryField(valueField) && + !metadataField.Nullable && !valueField.Nullable +} + +// NOTE: this is still experimental, a future change will add shredding support. +func newVariantType(storageType arrow.DataType) (*variantExtensionType, error) { + if !isSupportedVariantStorage(storageType) { + return nil, fmt.Errorf("%w: invalid storage type for unshredded variant: %s", + arrow.ErrInvalid, storageType.String()) + } + + var ( + mdField, valField arrow.Field + ) + + // shredded variants will eventually need to handle an optional shredded_value + // as well as value being optional + dt := storageType.(*arrow.StructType) + if dt.Field(0).Name == "metadata" { + mdField = dt.Field(0) + valField = dt.Field(1) + } else { + mdField = dt.Field(1) + valField = dt.Field(0) + } + + return &variantExtensionType{ + ExtensionBase: arrow.ExtensionBase{Storage: storageType}, + metadata: mdField, + value: valField, + }, nil +} + +func (v *variantExtensionType) Metadata() arrow.Field { return v.metadata } +func (v *variantExtensionType) Value() arrow.Field { return v.value } + +func (*variantExtensionType) ArrayType() reflect.Type { + return reflect.TypeOf(variantArray{}) +} + +func (*variantExtensionType) ExtensionName() string { + return "parquet.variant" +} + +func (v *variantExtensionType) String() string { + return fmt.Sprintf("extension<%s>", v.ExtensionName()) +} + +func (v *variantExtensionType) ExtensionEquals(other arrow.ExtensionType) bool { + return v.ExtensionName() == other.ExtensionName() && + arrow.TypeEqual(v.Storage, other.StorageType()) +} + +func (*variantExtensionType) Serialize() string { return "" } +func (*variantExtensionType) Deserialize(storageType arrow.DataType, _ string) (arrow.ExtensionType, error) { + return newVariantType(storageType) +} diff --git a/parquet/pqarrow/variant_test.go b/parquet/pqarrow/variant_test.go new file mode 100644 index 00000000..37fe6bb3 --- /dev/null +++ b/parquet/pqarrow/variant_test.go @@ -0,0 +1,67 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package pqarrow + +import ( + "testing" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestVariantExtensionType(t *testing.T) { + variant1, err := newVariantType(arrow.StructOf( + arrow.Field{Name: "metadata", Type: arrow.BinaryTypes.Binary, Nullable: false}, + arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, Nullable: false})) + require.NoError(t, err) + variant2, err := newVariantType(arrow.StructOf( + arrow.Field{Name: "metadata", Type: arrow.BinaryTypes.Binary, Nullable: false}, + arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, Nullable: false})) + require.NoError(t, err) + + assert.True(t, arrow.TypeEqual(variant1, variant2)) + + // can be provided in either order + variantFieldsFlipped, err := newVariantType(arrow.StructOf( + arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, Nullable: false}, + arrow.Field{Name: "metadata", Type: arrow.BinaryTypes.Binary, Nullable: false})) + require.NoError(t, err) + + assert.Equal(t, "metadata", variantFieldsFlipped.Metadata().Name) + assert.Equal(t, "value", variantFieldsFlipped.Value().Name) + + invalidTypes := []arrow.DataType{ + arrow.StructOf(arrow.Field{Name: "metadata", Type: arrow.BinaryTypes.Binary}), + arrow.StructOf(arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary}), + arrow.StructOf(arrow.Field{Name: "metadata", Type: arrow.BinaryTypes.Binary}, + arrow.Field{Name: "value", Type: arrow.PrimitiveTypes.Int32}), + arrow.StructOf(arrow.Field{Name: "metadata", Type: arrow.BinaryTypes.Binary}, + arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary}, + arrow.Field{Name: "extra", Type: arrow.BinaryTypes.Binary}), + arrow.StructOf(arrow.Field{Name: "metadata", Type: arrow.BinaryTypes.Binary, Nullable: true}, + arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, Nullable: false}), + arrow.StructOf(arrow.Field{Name: "metadata", Type: arrow.BinaryTypes.Binary, Nullable: false}, + arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, Nullable: true}), + } + + for _, tt := range invalidTypes { + _, err := newVariantType(tt) + assert.Error(t, err) + assert.ErrorContains(t, err, "invalid storage type for unshredded variant: "+tt.String()) + } +} diff --git a/parquet/schema/logical_types.go b/parquet/schema/logical_types.go index 1e0ed949..0c0ce559 100644 --- a/parquet/schema/logical_types.go +++ b/parquet/schema/logical_types.go @@ -70,6 +70,8 @@ func getLogicalType(l *format.LogicalType) LogicalType { return UUIDLogicalType{} case l.IsSetFLOAT16(): return Float16LogicalType{} + case l.IsSetVARIANT(): + return VariantLogicalType{} case l == nil: return NoLogicalType{} default: @@ -1110,6 +1112,41 @@ func (Float16LogicalType) Equals(rhs LogicalType) bool { return ok } +type VariantLogicalType struct{ baseLogicalType } + +func (VariantLogicalType) IsNested() bool { return true } + +func (VariantLogicalType) SortOrder() SortOrder { + return SortUNKNOWN +} + +func (VariantLogicalType) MarshalJSON() ([]byte, error) { + return json.Marshal(map[string]string{"Type": VariantLogicalType{}.String()}) +} + +func (VariantLogicalType) String() string { + return "Variant" +} + +func (VariantLogicalType) ToConvertedType() (ConvertedType, DecimalMetadata) { + return ConvertedTypes.None, DecimalMetadata{} +} + +func (VariantLogicalType) IsCompatible(ct ConvertedType, _ DecimalMetadata) bool { + return ct == ConvertedTypes.None +} + +func (VariantLogicalType) IsApplicable(parquet.Type, int32) bool { return false } + +func 
(VariantLogicalType) toThrift() *format.LogicalType { + return &format.LogicalType{VARIANT: format.NewVariantType()} +} + +func (VariantLogicalType) Equals(rhs LogicalType) bool { + _, ok := rhs.(VariantLogicalType) + return ok +} + type NullLogicalType struct{ baseLogicalType } func (NullLogicalType) SortOrder() SortOrder { diff --git a/parquet/schema/logical_types_test.go b/parquet/schema/logical_types_test.go index c4f3e091..62d9294b 100644 --- a/parquet/schema/logical_types_test.go +++ b/parquet/schema/logical_types_test.go @@ -160,6 +160,7 @@ func TestNewTypeIncompatibility(t *testing.T) { {"uuid", schema.UUIDLogicalType{}, schema.UUIDLogicalType{}}, {"float16", schema.Float16LogicalType{}, schema.Float16LogicalType{}}, {"null", schema.NullLogicalType{}, schema.NullLogicalType{}}, + {"variant", schema.VariantLogicalType{}, schema.VariantLogicalType{}}, {"not-utc-time_milli", schema.NewTimeLogicalType(false /* adjustedToUTC */, schema.TimeUnitMillis), schema.TimeLogicalType{}}, {"not-utc-time-micro", schema.NewTimeLogicalType(false /* adjustedToUTC */, schema.TimeUnitMicros), schema.TimeLogicalType{}}, {"not-utc-time-nano", schema.NewTimeLogicalType(false /* adjustedToUTC */, schema.TimeUnitNanos), schema.TimeLogicalType{}}, @@ -226,6 +227,7 @@ func TestLogicalTypeProperties(t *testing.T) { {"bson", schema.BSONLogicalType{}, false, true, true}, {"uuid", schema.UUIDLogicalType{}, false, true, true}, {"float16", schema.Float16LogicalType{}, false, true, true}, + {"variant", schema.VariantLogicalType{}, true, true, true}, {"nological", schema.NoLogicalType{}, false, false, true}, {"unknown", schema.UnknownLogicalType{}, false, false, false}, } @@ -456,6 +458,7 @@ func TestLogicalTypeRepresentation(t *testing.T) { {"bson", schema.BSONLogicalType{}, "BSON", `{"Type": "BSON"}`}, {"uuid", schema.UUIDLogicalType{}, "UUID", `{"Type": "UUID"}`}, {"float16", schema.Float16LogicalType{}, "Float16", `{"Type": "Float16"}`}, + {"variant", schema.VariantLogicalType{}, "Variant", `{"Type": "Variant"}`}, {"none", schema.NoLogicalType{}, "None", `{"Type": "None"}`}, } @@ -502,6 +505,7 @@ func TestLogicalTypeSortOrder(t *testing.T) { {"bson", schema.BSONLogicalType{}, schema.SortUNSIGNED}, {"uuid", schema.UUIDLogicalType{}, schema.SortUNSIGNED}, {"float16", schema.Float16LogicalType{}, schema.SortSIGNED}, + {"variant", schema.VariantLogicalType{}, schema.SortUNKNOWN}, {"none", schema.NoLogicalType{}, schema.SortUNKNOWN}, } diff --git a/parquet/writer_properties.go b/parquet/writer_properties.go index 3ee5f79e..3151e2f8 100644 --- a/parquet/writer_properties.go +++ b/parquet/writer_properties.go @@ -17,6 +17,7 @@ package parquet import ( + "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/memory" "github.com/apache/arrow-go/v18/parquet/compress" format "github.com/apache/arrow-go/v18/parquet/internal/gen-go/parquet" @@ -49,41 +50,60 @@ const ( DefaultMaxStatsSize int64 = 4096 // Default is to not write page indexes for columns DefaultPageIndexEnabled = false - DefaultCreatedBy = "parquet-go version 18.0.0-SNAPSHOT" + DefaultCreatedBy = "parquet-go version " + arrow.PkgVersion DefaultRootName = "schema" + + DefaultMaxBloomFilterBytes = 1024 * 1024 + DefaultBloomFilterEnabled = false + DefaultBloomFilterFPP = 0.01 + DefaultAdaptiveBloomFilterEnabled = false + DefaultBloomFilterCandidates = 5 ) // ColumnProperties defines the encoding, codec, and so on for a given column. 
type ColumnProperties struct { - Encoding Encoding - Codec compress.Compression - DictionaryEnabled bool - StatsEnabled bool - PageIndexEnabled bool - MaxStatsSize int64 - CompressionLevel int + Encoding Encoding + Codec compress.Compression + DictionaryEnabled bool + StatsEnabled bool + PageIndexEnabled bool + MaxStatsSize int64 + CompressionLevel int + BloomFilterEnabled bool + BloomFilterFPP float64 + AdaptiveBloomFilterEnabled bool + BloomFilterCandidates int + BloomFilterNDV int64 } // DefaultColumnProperties returns the default properties which get utilized for writing. // // The default column properties are the following constants: // -// Encoding: Encodings.Plain -// Codec: compress.Codecs.Uncompressed -// DictionaryEnabled: DefaultDictionaryEnabled -// StatsEnabled: DefaultStatsEnabled -// PageIndexEnabled: DefaultPageIndexEnabled -// MaxStatsSize: DefaultMaxStatsSize -// CompressionLevel: compress.DefaultCompressionLevel +// Encoding: Encodings.Plain +// Codec: compress.Codecs.Uncompressed +// DictionaryEnabled: DefaultDictionaryEnabled +// StatsEnabled: DefaultStatsEnabled +// PageIndexEnabled: DefaultPageIndexEnabled +// MaxStatsSize: DefaultMaxStatsSize +// CompressionLevel: compress.DefaultCompressionLevel +// BloomFilterEnabled: DefaultBloomFilterEnabled +// BloomFilterFPP: DefaultBloomFilterFPP +// AdaptiveBloomFilterEnabled: DefaultAdaptiveBloomFilterEnabled +// BloomFilterCandidates: DefaultBloomFilterCandidates func DefaultColumnProperties() ColumnProperties { return ColumnProperties{ - Encoding: Encodings.Plain, - Codec: compress.Codecs.Uncompressed, - DictionaryEnabled: DefaultDictionaryEnabled, - StatsEnabled: DefaultStatsEnabled, - PageIndexEnabled: DefaultPageIndexEnabled, - MaxStatsSize: DefaultMaxStatsSize, - CompressionLevel: compress.DefaultCompressionLevel, + Encoding: Encodings.Plain, + Codec: compress.Codecs.Uncompressed, + DictionaryEnabled: DefaultDictionaryEnabled, + StatsEnabled: DefaultStatsEnabled, + PageIndexEnabled: DefaultPageIndexEnabled, + MaxStatsSize: DefaultMaxStatsSize, + CompressionLevel: compress.DefaultCompressionLevel, + BloomFilterEnabled: DefaultBloomFilterEnabled, + BloomFilterFPP: DefaultBloomFilterFPP, + AdaptiveBloomFilterEnabled: DefaultAdaptiveBloomFilterEnabled, + BloomFilterCandidates: DefaultBloomFilterCandidates, } } @@ -91,13 +111,18 @@ func DefaultColumnProperties() ColumnProperties { type SortingColumn = format.SortingColumn type writerPropConfig struct { - wr *WriterProperties - encodings map[string]Encoding - codecs map[string]compress.Compression - compressLevel map[string]int - dictEnabled map[string]bool - statsEnabled map[string]bool - indexEnabled map[string]bool + wr *WriterProperties + encodings map[string]Encoding + codecs map[string]compress.Compression + compressLevel map[string]int + dictEnabled map[string]bool + statsEnabled map[string]bool + indexEnabled map[string]bool + bloomFilterNDVs map[string]int64 + bloomFilterFPPs map[string]float64 + bloomFilterEnabled map[string]bool + adaptiveBloomFilterEnabled map[string]bool + numBloomFilterCandidates map[string]int } // WriterProperty is used as the options for building a writer properties instance @@ -337,20 +362,142 @@ func WithPageIndexEnabledPath(path ColumnPath, enabled bool) WriterProperty { return WithPageIndexEnabledFor(path.String(), enabled) } +// WithMaxBloomFilterBytes sets the maximum size for a bloom filter, after which +// it is abandoned and not written to the file. 
+func WithMaxBloomFilterBytes(nbytes int64) WriterProperty { + return func(cfg *writerPropConfig) { + cfg.wr.maxBloomFilterBytes = nbytes + } +} + +// WithBloomFilterEnabled sets the default value for whether to enable writing bloom +// filters for columns. This is the default value for all columns, but can be overridden +// by using WithBloomFilterEnabledFor or WithBloomFilterEnabledPath. +func WithBloomFilterEnabled(enabled bool) WriterProperty { + return func(cfg *writerPropConfig) { + cfg.wr.defColumnProps.BloomFilterEnabled = enabled + } +} + +// WithBloomFilterEnabledFor specifies a per column value for whether to enable or disable writing +// bloom filters for the column. +func WithBloomFilterEnabledFor(path string, enabled bool) WriterProperty { + return func(cfg *writerPropConfig) { + cfg.bloomFilterEnabled[path] = enabled + } +} + +// WithBloomFilterEnabledPath is like WithBloomFilterEnabledFor, but takes a ColumnPath +func WithBloomFilterEnabledPath(path ColumnPath, enabled bool) WriterProperty { + return WithBloomFilterEnabledFor(path.String(), enabled) +} + +// WithBloomFilterFPP sets the default value for the false positive probability for writing +// bloom filters. +func WithBloomFilterFPP(fpp float64) WriterProperty { + return func(cfg *writerPropConfig) { + cfg.wr.defColumnProps.BloomFilterFPP = fpp + } +} + +// WithBloomFilterFPPFor specifies a per column value for the false positive probability +// for writing bloom filters. +func WithBloomFilterFPPFor(path string, fpp float64) WriterProperty { + return func(cfg *writerPropConfig) { + cfg.bloomFilterFPPs[path] = fpp + } +} + +// WithBloomFilterFPPPath is like WithBloomFilterFPPFor, but takes a ColumnPath +func WithBloomFilterFPPPath(path ColumnPath, fpp float64) WriterProperty { + return WithBloomFilterFPPFor(path.String(), fpp) +} + +// WithAdaptiveBloomFilterEnabled sets the default value for whether to enable writing +// adaptive bloom filters for columns. This is the default value for all columns, +// but can be overridden by using WithAdaptiveBloomFilterEnabledFor or +// WithAdaptiveBloomFilterEnabledPath. +// +// Using an adaptive bloom filter will attempt to use multiple candidate bloom filters +// when building the column, each with a different expected number of distinct values. It +// will use the smallest candidate bloom filter that achieves the desired false positive +// probability, dropping candidate bloom filters that are no longer viable. +func WithAdaptiveBloomFilterEnabled(enabled bool) WriterProperty { + return func(cfg *writerPropConfig) { + cfg.wr.defColumnProps.AdaptiveBloomFilterEnabled = enabled + } +} + +// WithAdaptiveBloomFilterEnabledFor specifies a per column value for whether to enable or disable writing +// adaptive bloom filters for the column. +func WithAdaptiveBloomFilterEnabledFor(path string, enabled bool) WriterProperty { + return func(cfg *writerPropConfig) { + cfg.adaptiveBloomFilterEnabled[path] = enabled + } +} + +// WithAdaptiveBloomFilterEnabledPath is like WithAdaptiveBloomFilterEnabledFor, but takes a ColumnPath +func WithAdaptiveBloomFilterEnabledPath(path ColumnPath, enabled bool) WriterProperty { + return WithAdaptiveBloomFilterEnabledFor(path.String(), enabled) +} + +// WithBloomFilterCandidates sets the number of candidate bloom filters to use when building +// an adaptive bloom filter.
+func WithBloomFilterCandidates(candidates int) WriterProperty { + return func(cfg *writerPropConfig) { + cfg.wr.defColumnProps.BloomFilterCandidates = candidates + } +} + +// WithBloomFilterCandidatesFor specifies a per column value for the number of candidate +// bloom filters to use when building an adaptive bloom filter. +func WithBloomFilterCandidatesFor(path string, candidates int) WriterProperty { + return func(cfg *writerPropConfig) { + cfg.numBloomFilterCandidates[path] = candidates + } +} + +// WithBloomFilterCandidatesPath is like WithBloomFilterCandidatesFor, but takes a ColumnPath +func WithBloomFilterCandidatesPath(path ColumnPath, candidates int) WriterProperty { + return WithBloomFilterCandidatesFor(path.String(), candidates) +} + +// WithBloomFilterNDV sets the default value for the expected number of distinct values +// to be written for the column. This is ignored when using adaptive bloom filters. +func WithBloomFilterNDV(ndv int64) WriterProperty { + return func(cfg *writerPropConfig) { + cfg.wr.defColumnProps.BloomFilterNDV = ndv + } +} + +// WithBloomFilterNDVFor specifies a per column value for the expected number of distinct values +// to be written for the column. This is ignored when using adaptive bloom filters. +func WithBloomFilterNDVFor(path string, ndv int64) WriterProperty { + return func(cfg *writerPropConfig) { + cfg.bloomFilterNDVs[path] = ndv + } +} + +// WithBloomFilterNDVPath is like WithBloomFilterNDVFor, but takes a ColumnPath +func WithBloomFilterNDVPath(path ColumnPath, ndv int64) WriterProperty { + return WithBloomFilterNDVFor(path.String(), ndv) +} + // WriterProperties is the collection of properties to use for writing a parquet file. The values are // read only once it has been constructed. type WriterProperties struct { - mem memory.Allocator - dictPagesize int64 - batchSize int64 - maxRowGroupLen int64 - pageSize int64 - parquetVersion Version - createdBy string - dataPageVersion DataPageVersion - rootName string - rootRepetition Repetition - storeDecimalAsInt bool + mem memory.Allocator + dictPagesize int64 + batchSize int64 + maxRowGroupLen int64 + pageSize int64 + parquetVersion Version + createdBy string + dataPageVersion DataPageVersion + rootName string + rootRepetition Repetition + storeDecimalAsInt bool + maxBloomFilterBytes int64 defColumnProps ColumnProperties columnProps map[string]*ColumnProperties @@ -360,18 +507,19 @@ type WriterProperties struct { func defaultWriterProperties() *WriterProperties { return &WriterProperties{ - mem: memory.DefaultAllocator, - dictPagesize: DefaultDictionaryPageSizeLimit, - batchSize: DefaultWriteBatchSize, - maxRowGroupLen: DefaultMaxRowGroupLen, - pageSize: DefaultDataPageSize, - parquetVersion: V2_LATEST, - dataPageVersion: DataPageV1, - createdBy: DefaultCreatedBy, - rootName: DefaultRootName, - rootRepetition: Repetitions.Repeated, - defColumnProps: DefaultColumnProperties(), - sortingCols: []SortingColumn{}, + mem: memory.DefaultAllocator, + dictPagesize: DefaultDictionaryPageSizeLimit, + batchSize: DefaultWriteBatchSize, + maxRowGroupLen: DefaultMaxRowGroupLen, + pageSize: DefaultDataPageSize, + parquetVersion: V2_LATEST, + dataPageVersion: DataPageV1, + createdBy: DefaultCreatedBy, + rootName: DefaultRootName, + rootRepetition: Repetitions.Repeated, + maxBloomFilterBytes: DefaultMaxBloomFilterBytes, + defColumnProps: DefaultColumnProperties(), + sortingCols: []SortingColumn{}, } } @@ -381,23 +529,28 @@ func defaultWriterProperties() *WriterProperties { // // The Default properties use the 
following constants: // -// Allocator: memory.DefaultAllocator +// Allocator: memory.DefaultAllocator // DictionaryPageSize: DefaultDictionaryPageSizeLimit -// BatchSize: DefaultWriteBatchSize -// MaxRowGroupLength: DefaultMaxRowGroupLen -// PageSize: DefaultDataPageSize -// ParquetVersion: V1 -// DataPageVersion: DataPageV1 -// CreatedBy: DefaultCreatedBy +// BatchSize: DefaultWriteBatchSize +// MaxRowGroupLength: DefaultMaxRowGroupLen +// PageSize: DefaultDataPageSize +// ParquetVersion: V2_LATEST +// DataPageVersion: DataPageV1 +// CreatedBy: DefaultCreatedBy func NewWriterProperties(opts ...WriterProperty) *WriterProperties { cfg := writerPropConfig{ - wr: defaultWriterProperties(), - encodings: make(map[string]Encoding), - codecs: make(map[string]compress.Compression), - compressLevel: make(map[string]int), - dictEnabled: make(map[string]bool), - statsEnabled: make(map[string]bool), - indexEnabled: make(map[string]bool), + wr: defaultWriterProperties(), + encodings: make(map[string]Encoding), + codecs: make(map[string]compress.Compression), + compressLevel: make(map[string]int), + dictEnabled: make(map[string]bool), + statsEnabled: make(map[string]bool), + indexEnabled: make(map[string]bool), + bloomFilterNDVs: make(map[string]int64), + bloomFilterFPPs: make(map[string]float64), + bloomFilterEnabled: make(map[string]bool), + adaptiveBloomFilterEnabled: make(map[string]bool), + numBloomFilterCandidates: make(map[string]int), } for _, o := range opts { o(&cfg) @@ -436,6 +589,27 @@ func NewWriterProperties(opts ...WriterProperty) *WriterProperties { for key, value := range cfg.indexEnabled { get(key).PageIndexEnabled = value } + + for key, value := range cfg.bloomFilterEnabled { + get(key).BloomFilterEnabled = value + } + + for key, value := range cfg.bloomFilterFPPs { + get(key).BloomFilterFPP = value + } + + for key, value := range cfg.bloomFilterNDVs { + get(key).BloomFilterNDV = value + } + + for key, value := range cfg.adaptiveBloomFilterEnabled { + get(key).AdaptiveBloomFilterEnabled = value + } + + for key, value := range cfg.numBloomFilterCandidates { + get(key).BloomFilterCandidates = value + } + return cfg.wr } @@ -613,3 +787,98 @@ func (w *WriterProperties) ColumnEncryptionProperties(path string) *ColumnEncryp func (w *WriterProperties) StoreDecimalAsInteger() bool { return w.storeDecimalAsInt } + +// MaxBloomFilterBytes returns the maximum number of bytes that a bloom filter can use +func (w *WriterProperties) MaxBloomFilterBytes() int64 { + return w.maxBloomFilterBytes +} + +// BloomFilterEnabled returns the default value for whether or not bloom filters are enabled +func (w *WriterProperties) BloomFilterEnabled() bool { + return w.defColumnProps.BloomFilterEnabled +} + +// BloomFilterEnabledFor returns whether or not bloom filters are enabled for the given column path +func (w *WriterProperties) BloomFilterEnabledFor(path string) bool { + if p, ok := w.columnProps[path]; ok { + return p.BloomFilterEnabled + } + return w.defColumnProps.BloomFilterEnabled +} + +// BloomFilterEnabledPath is the same as BloomFilterEnabledFor but takes a ColumnPath +func (w *WriterProperties) BloomFilterEnabledPath(path ColumnPath) bool { + return w.BloomFilterEnabledFor(path.String()) +} + +// BloomFilterFPP returns the default false positive probability for bloom filters +func (w *WriterProperties) BloomFilterFPP() float64 { + return w.defColumnProps.BloomFilterFPP +} + +// BloomFilterFPPFor returns the false positive probability for the given column path +func (w *WriterProperties) 
BloomFilterFPPFor(path string) float64 { + if p, ok := w.columnProps[path]; ok { + return p.BloomFilterFPP + } + return w.defColumnProps.BloomFilterFPP +} + +// BloomFilterFPPPath is the same as BloomFilterFPPFor but takes a ColumnPath +func (w *WriterProperties) BloomFilterFPPPath(path ColumnPath) float64 { + return w.BloomFilterFPPFor(path.String()) +} + +// AdaptiveBloomFilterEnabled returns the default value for whether or not adaptive bloom filters are enabled +func (w *WriterProperties) AdaptiveBloomFilterEnabled() bool { + return w.defColumnProps.AdaptiveBloomFilterEnabled +} + +// AdaptiveBloomFilterEnabledFor returns whether or not adaptive bloom filters are enabled for the given column path +func (w *WriterProperties) AdaptiveBloomFilterEnabledFor(path string) bool { + if p, ok := w.columnProps[path]; ok { + return p.AdaptiveBloomFilterEnabled + } + return w.defColumnProps.AdaptiveBloomFilterEnabled +} + +// AdaptiveBloomFilterEnabledPath is the same as AdaptiveBloomFilterEnabledFor but takes a ColumnPath +func (w *WriterProperties) AdaptiveBloomFilterEnabledPath(path ColumnPath) bool { + return w.AdaptiveBloomFilterEnabledFor(path.String()) +} + +// BloomFilterCandidates returns the default number of candidates to use for bloom filters +func (w *WriterProperties) BloomFilterCandidates() int { + return w.defColumnProps.BloomFilterCandidates +} + +// BloomFilterCandidatesFor returns the number of candidates to use for the given column path +func (w *WriterProperties) BloomFilterCandidatesFor(path string) int { + if p, ok := w.columnProps[path]; ok { + return p.BloomFilterCandidates + } + return w.defColumnProps.BloomFilterCandidates +} + +// BloomFilterCandidatesPath is the same as BloomFilterCandidatesFor but takes a ColumnPath +func (w *WriterProperties) BloomFilterCandidatesPath(path ColumnPath) int { + return w.BloomFilterCandidatesFor(path.String()) +} + +// BloomFilterNDV returns the default number of distinct values to use for bloom filters +func (w *WriterProperties) BloomFilterNDV() int64 { + return w.defColumnProps.BloomFilterNDV +} + +// BloomFilterNDVFor returns the number of distinct values to use for the given column path +func (w *WriterProperties) BloomFilterNDVFor(path string) int64 { + if p, ok := w.columnProps[path]; ok { + return p.BloomFilterNDV + } + return w.defColumnProps.BloomFilterNDV +} + +// BloomFilterNDVPath is the same as BloomFilterNDVFor but takes a ColumnPath +func (w *WriterProperties) BloomFilterNDVPath(path ColumnPath) int64 { + return w.BloomFilterNDVFor(path.String()) +}
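Taken together, a sketch of how the new bloom filter writer options and accessors above are meant to be used (the column names "user_id" and "session_id" are hypothetical; this example is not part of the patch):

package main

import (
	"fmt"

	"github.com/apache/arrow-go/v18/parquet"
)

func main() {
	props := parquet.NewWriterProperties(
		parquet.WithBloomFilterEnabled(true),                // default for every column
		parquet.WithBloomFilterFPP(0.01),                    // default false positive probability
		parquet.WithBloomFilterNDVFor("user_id", 1_000_000), // expected distinct values for one column
		parquet.WithAdaptiveBloomFilterEnabledFor("session_id", true),
		parquet.WithBloomFilterCandidatesFor("session_id", 8),
		parquet.WithMaxBloomFilterBytes(1<<20),
	)

	fmt.Println(props.BloomFilterEnabledFor("user_id"))            // true (inherited default)
	fmt.Println(props.BloomFilterNDVFor("user_id"))                // 1000000
	fmt.Println(props.AdaptiveBloomFilterEnabledFor("session_id")) // true
	fmt.Println(props.MaxBloomFilterBytes())                       // 1048576
}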
