Avoid O(N^2) in VALUES with ordinals grouping (#130576)

Using the VALUES aggregator with ordinals grouping led to accidental 
quadratic complexity. Queries like FROM .. | STATS ... VALUES(field) ...
BY keyword-field are affected by this performance issue. This change
caches a sorted structure — previously used to fix a similar O(N^2)
problem when emitting the output block — during the merging phase of the
OrdinalsGroupingOperator.
This commit is contained in:
Nhat Nguyen 2025-07-07 14:51:29 -07:00 committed by GitHub
parent 02b2f5eb66
commit 59df1bfd51
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 570 additions and 221 deletions

View File

@ -95,7 +95,8 @@ public class ValuesAggregatorBenchmark {
try {
for (String groups : ValuesAggregatorBenchmark.class.getField("groups").getAnnotationsByType(Param.class)[0].value()) {
for (String dataType : ValuesAggregatorBenchmark.class.getField("dataType").getAnnotationsByType(Param.class)[0].value()) {
run(Integer.parseInt(groups), dataType, 10);
run(Integer.parseInt(groups), dataType, 10, 0);
run(Integer.parseInt(groups), dataType, 10, 1);
}
}
} catch (NoSuchFieldException e) {
@ -113,7 +114,10 @@ public class ValuesAggregatorBenchmark {
@Param({ BYTES_REF, INT, LONG })
public String dataType;
private static Operator operator(DriverContext driverContext, int groups, String dataType) {
@Param({ "0", "1" })
public int numOrdinalMerges;
private static Operator operator(DriverContext driverContext, int groups, String dataType, int numOrdinalMerges) {
if (groups == 1) {
return new AggregationOperator(
List.of(supplier(dataType).aggregatorFactory(AggregatorMode.SINGLE, List.of(0)).apply(driverContext)),
@ -125,7 +129,24 @@ public class ValuesAggregatorBenchmark {
List.of(supplier(dataType).groupingAggregatorFactory(AggregatorMode.SINGLE, List.of(1))),
() -> BlockHash.build(groupSpec, driverContext.blockFactory(), 16 * 1024, false),
driverContext
);
) {
@Override
public Page getOutput() {
mergeOrdinal();
return super.getOutput();
}
// simulate OrdinalsGroupingOperator
void mergeOrdinal() {
var merged = supplier(dataType).groupingAggregatorFactory(AggregatorMode.SINGLE, List.of(1)).apply(driverContext);
for (int i = 0; i < numOrdinalMerges; i++) {
for (int p = 0; p < groups; p++) {
merged.addIntermediateRow(p, aggregators.getFirst(), p);
}
}
aggregators.set(0, merged);
}
};
}
private static AggregatorFunctionSupplier supplier(String dataType) {
@ -331,12 +352,12 @@ public class ValuesAggregatorBenchmark {
@Benchmark
public void run() {
run(groups, dataType, OP_COUNT);
run(groups, dataType, OP_COUNT, numOrdinalMerges);
}
private static void run(int groups, String dataType, int opCount) {
private static void run(int groups, String dataType, int opCount, int numOrdinalMerges) {
DriverContext driverContext = driverContext();
try (Operator operator = operator(driverContext, groups, dataType)) {
try (Operator operator = operator(driverContext, groups, dataType, numOrdinalMerges)) {
Page page = page(groups, dataType);
for (int i = 0; i < opCount; i++) {
operator.addInput(page.shallowCopy());

View File

@ -0,0 +1,5 @@
pr: 130576
summary: Avoid O(N^2) in VALUES with ordinals grouping
area: ES|QL
type: bug
issues: []

View File

@ -24,6 +24,7 @@ import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.OrdinalBytesRefBlock;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
/**
@ -55,8 +56,8 @@ class ValuesBytesRefAggregator {
return state.toBlock(driverContext.blockFactory());
}
public static GroupingState initGrouping(BigArrays bigArrays) {
return new GroupingState(bigArrays);
public static GroupingState initGrouping(DriverContext driverContext) {
return new GroupingState(driverContext);
}
public static GroupingAggregatorFunction.AddInput wrapAddInput(
@ -76,7 +77,7 @@ class ValuesBytesRefAggregator {
}
public static void combine(GroupingState state, int groupId, BytesRef v) {
state.values.add(groupId, BlockHash.hashOrdToGroup(state.bytes.add(v)));
state.addValue(groupId, v);
}
public static void combineIntermediate(GroupingState state, int groupId, BytesRefBlock values, int valuesPosition) {
@ -84,17 +85,20 @@ class ValuesBytesRefAggregator {
int start = values.getFirstValueIndex(valuesPosition);
int end = start + values.getValueCount(valuesPosition);
for (int i = start; i < end; i++) {
combine(state, groupId, values.getBytesRef(i, scratch));
state.addValue(groupId, values.getBytesRef(i, scratch));
}
}
public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) {
BytesRef scratch = new BytesRef();
for (int id = 0; id < state.values.size(); id++) {
if (state.values.getKey1(id) == statePosition) {
long value = state.values.getKey2(id);
combine(current, currentGroupId, state.bytes.get(value, scratch));
}
var sorted = state.sortedForOrdinalMerging(current);
if (statePosition > state.maxGroupId) {
return;
}
var start = statePosition > 0 ? sorted.counts[statePosition - 1] : 0;
var end = sorted.counts[statePosition];
for (int i = start; i < end; i++) {
int id = sorted.ids[i];
current.addValueOrdinal(currentGroupId, id);
}
}
@ -138,6 +142,22 @@ class ValuesBytesRefAggregator {
}
}
/**
* Values are collected in a hash. Iterating over them in order (row by row) to build the output,
* or merging with other state, can be expensive. To optimize this, we build a sorted structure once,
* and then use it to iterate over the values in order.
*
* @param ids positions of the {@link GroupingState#values} to read.
* If built from {@link GroupingState#sortedForOrdinalMerging(GroupingState)},
* these are ordinals referring to the {@link GroupingState#bytes} in the target state.
*/
private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable {
@Override
public void close() {
releasable.close();
}
}
/**
* State for a grouped {@code VALUES} aggregation. This implementation
* emphasizes collect-time performance over the performance of rendering
@ -146,15 +166,20 @@ class ValuesBytesRefAggregator {
* collector operation. But at least it's fairly simple.
*/
public static class GroupingState implements GroupingAggregatorState {
final LongLongHash values;
private int maxGroupId = -1;
private final BlockFactory blockFactory;
private final LongLongHash values;
BytesRefHash bytes;
private GroupingState(BigArrays bigArrays) {
private Sorted sortedForOrdinalMerging = null;
private GroupingState(DriverContext driverContext) {
this.blockFactory = driverContext.blockFactory();
LongLongHash _values = null;
BytesRefHash _bytes = null;
try {
_values = new LongLongHash(1, bigArrays);
_bytes = new BytesRefHash(1, bigArrays);
_values = new LongLongHash(1, driverContext.bigArrays());
_bytes = new BytesRefHash(1, driverContext.bigArrays());
values = _values;
bytes = _bytes;
@ -171,6 +196,16 @@ class ValuesBytesRefAggregator {
blocks[offset] = toBlock(driverContext.blockFactory(), selected);
}
void addValueOrdinal(int groupId, long valueOrdinal) {
values.add(groupId, valueOrdinal);
maxGroupId = Math.max(maxGroupId, groupId);
}
void addValue(int groupId, BytesRef v) {
values.add(groupId, BlockHash.hashOrdToGroup(bytes.add(v)));
maxGroupId = Math.max(maxGroupId, groupId);
}
/**
* Builds a {@link Block} with the unique values collected for the {@code #selected}
* groups. This is the implementation of the final and intermediate results of the agg.
@ -180,8 +215,19 @@ class ValuesBytesRefAggregator {
return blockFactory.newConstantNullBlock(selected.getPositionCount());
}
try (var sorted = buildSorted(selected)) {
if (OrdinalBytesRefBlock.isDense(selected.getPositionCount(), Math.toIntExact(values.size()))) {
return buildOrdinalOutputBlock(blockFactory, selected, sorted.counts, sorted.ids);
} else {
return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids);
}
}
}
private Sorted buildSorted(IntVector selected) {
long selectedCountsSize = 0;
long idsSize = 0;
Sorted sorted = null;
try {
/*
* Get a count of all groups less than the maximum selected group. Count
@ -256,16 +302,44 @@ class ValuesBytesRefAggregator {
ids[selectedCounts[group]++] = id;
}
}
if (OrdinalBytesRefBlock.isDense(selected.getPositionCount(), Math.toIntExact(values.size()))) {
return buildOrdinalOutputBlock(blockFactory, selected, selectedCounts, ids);
} else {
return buildOutputBlock(blockFactory, selected, selectedCounts, ids);
}
final long totalMemoryUsed = selectedCountsSize + idsSize;
sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids);
return sorted;
} finally {
blockFactory.adjustBreaker(-selectedCountsSize - idsSize);
if (sorted == null) {
blockFactory.adjustBreaker(-selectedCountsSize - idsSize);
}
}
}
private Sorted sortedForOrdinalMerging(GroupingState other) {
if (sortedForOrdinalMerging == null) {
try (var selected = IntVector.range(0, maxGroupId + 1, blockFactory)) {
sortedForOrdinalMerging = buildSorted(selected);
// hash all the bytes to the destination to avoid hashing them multiple times
BytesRef scratch = new BytesRef();
final int totalValue = Math.toIntExact(bytes.size());
blockFactory.adjustBreaker((long) totalValue * Integer.BYTES);
try {
final int[] mappedIds = new int[totalValue];
for (int i = 0; i < totalValue; i++) {
var v = bytes.get(i, scratch);
mappedIds[i] = Math.toIntExact(BlockHash.hashOrdToGroup(other.bytes.add(v)));
}
// no longer need the bytes
bytes.close();
bytes = null;
for (int i = 0; i < sortedForOrdinalMerging.ids.length; i++) {
sortedForOrdinalMerging.ids[i] = mappedIds[Math.toIntExact(values.getKey2(sortedForOrdinalMerging.ids[i]))];
}
} finally {
blockFactory.adjustBreaker(-(long) totalValue * Integer.BYTES);
}
}
}
return sortedForOrdinalMerging;
}
Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) {
/*
* Insert the ids in order.
@ -279,11 +353,11 @@ class ValuesBytesRefAggregator {
int count = end - start;
switch (count) {
case 0 -> builder.appendNull();
case 1 -> append(builder, ids[start], scratch);
case 1 -> builder.appendBytesRef(getValue(ids[start], scratch));
default -> {
builder.beginPositionEntry();
for (int i = start; i < end; i++) {
append(builder, ids[i], scratch);
builder.appendBytesRef(getValue(ids[i], scratch));
}
builder.endPositionEntry();
}
@ -331,9 +405,8 @@ class ValuesBytesRefAggregator {
}
}
private void append(BytesRefBlock.Builder builder, int id, BytesRef scratch) {
BytesRef value = bytes.get(values.getKey2(id), scratch);
builder.appendBytesRef(value);
BytesRef getValue(int valueId, BytesRef scratch) {
return bytes.get(values.getKey2(valueId), scratch);
}
@Override
@ -343,7 +416,7 @@ class ValuesBytesRefAggregator {
@Override
public void close() {
Releasables.closeExpectNoException(values, bytes);
Releasables.closeExpectNoException(values, bytes, sortedForOrdinalMerging);
}
}
}

View File

@ -19,6 +19,8 @@ import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
/**
* Aggregates field values for double.
@ -48,28 +50,32 @@ class ValuesDoubleAggregator {
return state.toBlock(driverContext.blockFactory());
}
public static GroupingState initGrouping(BigArrays bigArrays) {
return new GroupingState(bigArrays);
public static GroupingState initGrouping(DriverContext driverContext) {
return new GroupingState(driverContext);
}
public static void combine(GroupingState state, int groupId, double v) {
state.values.add(groupId, Double.doubleToLongBits(v));
state.addValue(groupId, v);
}
public static void combineIntermediate(GroupingState state, int groupId, DoubleBlock values, int valuesPosition) {
int start = values.getFirstValueIndex(valuesPosition);
int end = start + values.getValueCount(valuesPosition);
for (int i = start; i < end; i++) {
combine(state, groupId, values.getDouble(i));
state.addValue(groupId, values.getDouble(i));
}
}
public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) {
for (int id = 0; id < state.values.size(); id++) {
if (state.values.getKey1(id) == statePosition) {
double value = Double.longBitsToDouble(state.values.getKey2(id));
combine(current, currentGroupId, value);
}
var sorted = state.sortedForOrdinalMerging(current);
if (statePosition > state.maxGroupId) {
return;
}
var start = statePosition > 0 ? sorted.counts[statePosition - 1] : 0;
var end = sorted.counts[statePosition];
for (int i = start; i < end; i++) {
int id = sorted.ids[i];
current.addValue(currentGroupId, state.getValue(id));
}
}
@ -112,6 +118,20 @@ class ValuesDoubleAggregator {
}
}
/**
* Values are collected in a hash. Iterating over them in order (row by row) to build the output,
* or merging with other state, can be expensive. To optimize this, we build a sorted structure once,
* and then use it to iterate over the values in order.
*
* @param ids positions of the {@link GroupingState#values} to read.
*/
private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable {
@Override
public void close() {
releasable.close();
}
}
/**
* State for a grouped {@code VALUES} aggregation. This implementation
* emphasizes collect-time performance over the performance of rendering
@ -120,10 +140,15 @@ class ValuesDoubleAggregator {
* collector operation. But at least it's fairly simple.
*/
public static class GroupingState implements GroupingAggregatorState {
private int maxGroupId = -1;
private final BlockFactory blockFactory;
private final LongLongHash values;
private GroupingState(BigArrays bigArrays) {
values = new LongLongHash(1, bigArrays);
private Sorted sortedForOrdinalMerging = null;
private GroupingState(DriverContext driverContext) {
this.blockFactory = driverContext.blockFactory();
values = new LongLongHash(1, driverContext.bigArrays());
}
@Override
@ -131,6 +156,11 @@ class ValuesDoubleAggregator {
blocks[offset] = toBlock(driverContext.blockFactory(), selected);
}
void addValue(int groupId, double v) {
values.add(groupId, Double.doubleToLongBits(v));
maxGroupId = Math.max(maxGroupId, groupId);
}
/**
* Builds a {@link Block} with the unique values collected for the {@code #selected}
* groups. This is the implementation of the final and intermediate results of the agg.
@ -140,8 +170,15 @@ class ValuesDoubleAggregator {
return blockFactory.newConstantNullBlock(selected.getPositionCount());
}
try (var sorted = buildSorted(selected)) {
return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids);
}
}
private Sorted buildSorted(IntVector selected) {
long selectedCountsSize = 0;
long idsSize = 0;
Sorted sorted = null;
try {
/*
* Get a count of all groups less than the maximum selected group. Count
@ -216,12 +253,25 @@ class ValuesDoubleAggregator {
ids[selectedCounts[group]++] = id;
}
}
return buildOutputBlock(blockFactory, selected, selectedCounts, ids);
final long totalMemoryUsed = selectedCountsSize + idsSize;
sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids);
return sorted;
} finally {
blockFactory.adjustBreaker(-selectedCountsSize - idsSize);
if (sorted == null) {
blockFactory.adjustBreaker(-selectedCountsSize - idsSize);
}
}
}
private Sorted sortedForOrdinalMerging(GroupingState other) {
if (sortedForOrdinalMerging == null) {
try (var selected = IntVector.range(0, maxGroupId + 1, blockFactory)) {
sortedForOrdinalMerging = buildSorted(selected);
}
}
return sortedForOrdinalMerging;
}
Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) {
/*
* Insert the ids in order.
@ -234,11 +284,11 @@ class ValuesDoubleAggregator {
int count = end - start;
switch (count) {
case 0 -> builder.appendNull();
case 1 -> append(builder, ids[start]);
case 1 -> builder.appendDouble(getValue(ids[start]));
default -> {
builder.beginPositionEntry();
for (int i = start; i < end; i++) {
append(builder, ids[i]);
builder.appendDouble(getValue(ids[i]));
}
builder.endPositionEntry();
}
@ -249,9 +299,8 @@ class ValuesDoubleAggregator {
}
}
private void append(DoubleBlock.Builder builder, int id) {
double value = Double.longBitsToDouble(values.getKey2(id));
builder.appendDouble(value);
double getValue(int valueId) {
return Double.longBitsToDouble(values.getKey2(valueId));
}
@Override
@ -261,7 +310,7 @@ class ValuesDoubleAggregator {
@Override
public void close() {
values.close();
Releasables.closeExpectNoException(values, sortedForOrdinalMerging);
}
}
}

View File

@ -18,6 +18,8 @@ import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.FloatBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
/**
* Aggregates field values for float.
@ -47,34 +49,32 @@ class ValuesFloatAggregator {
return state.toBlock(driverContext.blockFactory());
}
public static GroupingState initGrouping(BigArrays bigArrays) {
return new GroupingState(bigArrays);
public static GroupingState initGrouping(DriverContext driverContext) {
return new GroupingState(driverContext);
}
public static void combine(GroupingState state, int groupId, float v) {
/*
* Encode the groupId and value into a single long -
* the top 32 bits for the group, the bottom 32 for the value.
*/
state.values.add((((long) groupId) << Float.SIZE) | (Float.floatToIntBits(v) & 0xFFFFFFFFL));
state.addValue(groupId, v);
}
public static void combineIntermediate(GroupingState state, int groupId, FloatBlock values, int valuesPosition) {
int start = values.getFirstValueIndex(valuesPosition);
int end = start + values.getValueCount(valuesPosition);
for (int i = start; i < end; i++) {
combine(state, groupId, values.getFloat(i));
state.addValue(groupId, values.getFloat(i));
}
}
public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) {
for (int id = 0; id < state.values.size(); id++) {
long both = state.values.get(id);
int group = (int) (both >>> Float.SIZE);
if (group == statePosition) {
float value = Float.intBitsToFloat((int) both);
combine(current, currentGroupId, value);
}
var sorted = state.sortedForOrdinalMerging(current);
if (statePosition > state.maxGroupId) {
return;
}
var start = statePosition > 0 ? sorted.counts[statePosition - 1] : 0;
var end = sorted.counts[statePosition];
for (int i = start; i < end; i++) {
int id = sorted.ids[i];
current.addValue(currentGroupId, state.getValue(id));
}
}
@ -117,6 +117,20 @@ class ValuesFloatAggregator {
}
}
/**
* Values are collected in a hash. Iterating over them in order (row by row) to build the output,
* or merging with other state, can be expensive. To optimize this, we build a sorted structure once,
* and then use it to iterate over the values in order.
*
* @param ids positions of the {@link GroupingState#values} to read.
*/
private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable {
@Override
public void close() {
releasable.close();
}
}
/**
* State for a grouped {@code VALUES} aggregation. This implementation
* emphasizes collect-time performance over the performance of rendering
@ -125,10 +139,15 @@ class ValuesFloatAggregator {
* collector operation. But at least it's fairly simple.
*/
public static class GroupingState implements GroupingAggregatorState {
private int maxGroupId = -1;
private final BlockFactory blockFactory;
private final LongHash values;
private GroupingState(BigArrays bigArrays) {
values = new LongHash(1, bigArrays);
private Sorted sortedForOrdinalMerging = null;
private GroupingState(DriverContext driverContext) {
this.blockFactory = driverContext.blockFactory();
values = new LongHash(1, driverContext.bigArrays());
}
@Override
@ -136,6 +155,15 @@ class ValuesFloatAggregator {
blocks[offset] = toBlock(driverContext.blockFactory(), selected);
}
void addValue(int groupId, float v) {
/*
* Encode the groupId and value into a single long -
* the top 32 bits for the group, the bottom 32 for the value.
*/
values.add((((long) groupId) << Float.SIZE) | (Float.floatToIntBits(v) & 0xFFFFFFFFL));
maxGroupId = Math.max(maxGroupId, groupId);
}
/**
* Builds a {@link Block} with the unique values collected for the {@code #selected}
* groups. This is the implementation of the final and intermediate results of the agg.
@ -145,8 +173,15 @@ class ValuesFloatAggregator {
return blockFactory.newConstantNullBlock(selected.getPositionCount());
}
try (var sorted = buildSorted(selected)) {
return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids);
}
}
private Sorted buildSorted(IntVector selected) {
long selectedCountsSize = 0;
long idsSize = 0;
Sorted sorted = null;
try {
/*
* Get a count of all groups less than the maximum selected group. Count
@ -223,12 +258,25 @@ class ValuesFloatAggregator {
ids[selectedCounts[group]++] = id;
}
}
return buildOutputBlock(blockFactory, selected, selectedCounts, ids);
final long totalMemoryUsed = selectedCountsSize + idsSize;
sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids);
return sorted;
} finally {
blockFactory.adjustBreaker(-selectedCountsSize - idsSize);
if (sorted == null) {
blockFactory.adjustBreaker(-selectedCountsSize - idsSize);
}
}
}
private Sorted sortedForOrdinalMerging(GroupingState other) {
if (sortedForOrdinalMerging == null) {
try (var selected = IntVector.range(0, maxGroupId + 1, blockFactory)) {
sortedForOrdinalMerging = buildSorted(selected);
}
}
return sortedForOrdinalMerging;
}
Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) {
/*
* Insert the ids in order.
@ -241,11 +289,11 @@ class ValuesFloatAggregator {
int count = end - start;
switch (count) {
case 0 -> builder.appendNull();
case 1 -> append(builder, ids[start]);
case 1 -> builder.appendFloat(getValue(ids[start]));
default -> {
builder.beginPositionEntry();
for (int i = start; i < end; i++) {
append(builder, ids[i]);
builder.appendFloat(getValue(ids[i]));
}
builder.endPositionEntry();
}
@ -256,10 +304,9 @@ class ValuesFloatAggregator {
}
}
private void append(FloatBlock.Builder builder, int id) {
long both = values.get(id);
float value = Float.intBitsToFloat((int) both);
builder.appendFloat(value);
float getValue(int valueId) {
long both = values.get(valueId);
return Float.intBitsToFloat((int) both);
}
@Override
@ -269,7 +316,7 @@ class ValuesFloatAggregator {
@Override
public void close() {
values.close();
Releasables.closeExpectNoException(values, sortedForOrdinalMerging);
}
}
}

View File

@ -18,6 +18,8 @@ import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
/**
* Aggregates field values for int.
@ -47,34 +49,32 @@ class ValuesIntAggregator {
return state.toBlock(driverContext.blockFactory());
}
public static GroupingState initGrouping(BigArrays bigArrays) {
return new GroupingState(bigArrays);
public static GroupingState initGrouping(DriverContext driverContext) {
return new GroupingState(driverContext);
}
public static void combine(GroupingState state, int groupId, int v) {
/*
* Encode the groupId and value into a single long -
* the top 32 bits for the group, the bottom 32 for the value.
*/
state.values.add((((long) groupId) << Integer.SIZE) | (v & 0xFFFFFFFFL));
state.addValue(groupId, v);
}
public static void combineIntermediate(GroupingState state, int groupId, IntBlock values, int valuesPosition) {
int start = values.getFirstValueIndex(valuesPosition);
int end = start + values.getValueCount(valuesPosition);
for (int i = start; i < end; i++) {
combine(state, groupId, values.getInt(i));
state.addValue(groupId, values.getInt(i));
}
}
public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) {
for (int id = 0; id < state.values.size(); id++) {
long both = state.values.get(id);
int group = (int) (both >>> Integer.SIZE);
if (group == statePosition) {
int value = (int) both;
combine(current, currentGroupId, value);
}
var sorted = state.sortedForOrdinalMerging(current);
if (statePosition > state.maxGroupId) {
return;
}
var start = statePosition > 0 ? sorted.counts[statePosition - 1] : 0;
var end = sorted.counts[statePosition];
for (int i = start; i < end; i++) {
int id = sorted.ids[i];
current.addValue(currentGroupId, state.getValue(id));
}
}
@ -117,6 +117,20 @@ class ValuesIntAggregator {
}
}
/**
* Values are collected in a hash. Iterating over them in order (row by row) to build the output,
* or merging with other state, can be expensive. To optimize this, we build a sorted structure once,
* and then use it to iterate over the values in order.
*
* @param ids positions of the {@link GroupingState#values} to read.
*/
private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable {
@Override
public void close() {
releasable.close();
}
}
/**
* State for a grouped {@code VALUES} aggregation. This implementation
* emphasizes collect-time performance over the performance of rendering
@ -125,10 +139,15 @@ class ValuesIntAggregator {
* collector operation. But at least it's fairly simple.
*/
public static class GroupingState implements GroupingAggregatorState {
private int maxGroupId = -1;
private final BlockFactory blockFactory;
private final LongHash values;
private GroupingState(BigArrays bigArrays) {
values = new LongHash(1, bigArrays);
private Sorted sortedForOrdinalMerging = null;
private GroupingState(DriverContext driverContext) {
this.blockFactory = driverContext.blockFactory();
values = new LongHash(1, driverContext.bigArrays());
}
@Override
@ -136,6 +155,15 @@ class ValuesIntAggregator {
blocks[offset] = toBlock(driverContext.blockFactory(), selected);
}
void addValue(int groupId, int v) {
/*
* Encode the groupId and value into a single long -
* the top 32 bits for the group, the bottom 32 for the value.
*/
values.add((((long) groupId) << Integer.SIZE) | (v & 0xFFFFFFFFL));
maxGroupId = Math.max(maxGroupId, groupId);
}
/**
* Builds a {@link Block} with the unique values collected for the {@code #selected}
* groups. This is the implementation of the final and intermediate results of the agg.
@ -145,8 +173,15 @@ class ValuesIntAggregator {
return blockFactory.newConstantNullBlock(selected.getPositionCount());
}
try (var sorted = buildSorted(selected)) {
return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids);
}
}
private Sorted buildSorted(IntVector selected) {
long selectedCountsSize = 0;
long idsSize = 0;
Sorted sorted = null;
try {
/*
* Get a count of all groups less than the maximum selected group. Count
@ -223,12 +258,25 @@ class ValuesIntAggregator {
ids[selectedCounts[group]++] = id;
}
}
return buildOutputBlock(blockFactory, selected, selectedCounts, ids);
final long totalMemoryUsed = selectedCountsSize + idsSize;
sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids);
return sorted;
} finally {
blockFactory.adjustBreaker(-selectedCountsSize - idsSize);
if (sorted == null) {
blockFactory.adjustBreaker(-selectedCountsSize - idsSize);
}
}
}
private Sorted sortedForOrdinalMerging(GroupingState other) {
if (sortedForOrdinalMerging == null) {
try (var selected = IntVector.range(0, maxGroupId + 1, blockFactory)) {
sortedForOrdinalMerging = buildSorted(selected);
}
}
return sortedForOrdinalMerging;
}
Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) {
/*
* Insert the ids in order.
@ -241,11 +289,11 @@ class ValuesIntAggregator {
int count = end - start;
switch (count) {
case 0 -> builder.appendNull();
case 1 -> append(builder, ids[start]);
case 1 -> builder.appendInt(getValue(ids[start]));
default -> {
builder.beginPositionEntry();
for (int i = start; i < end; i++) {
append(builder, ids[i]);
builder.appendInt(getValue(ids[i]));
}
builder.endPositionEntry();
}
@ -256,10 +304,9 @@ class ValuesIntAggregator {
}
}
private void append(IntBlock.Builder builder, int id) {
long both = values.get(id);
int value = (int) both;
builder.appendInt(value);
int getValue(int valueId) {
long both = values.get(valueId);
return (int) both;
}
@Override
@ -269,7 +316,7 @@ class ValuesIntAggregator {
@Override
public void close() {
values.close();
Releasables.closeExpectNoException(values, sortedForOrdinalMerging);
}
}
}

View File

@ -19,6 +19,8 @@ import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
/**
* Aggregates field values for long.
@ -48,28 +50,32 @@ class ValuesLongAggregator {
return state.toBlock(driverContext.blockFactory());
}
public static GroupingState initGrouping(BigArrays bigArrays) {
return new GroupingState(bigArrays);
public static GroupingState initGrouping(DriverContext driverContext) {
return new GroupingState(driverContext);
}
public static void combine(GroupingState state, int groupId, long v) {
state.values.add(groupId, v);
state.addValue(groupId, v);
}
public static void combineIntermediate(GroupingState state, int groupId, LongBlock values, int valuesPosition) {
int start = values.getFirstValueIndex(valuesPosition);
int end = start + values.getValueCount(valuesPosition);
for (int i = start; i < end; i++) {
combine(state, groupId, values.getLong(i));
state.addValue(groupId, values.getLong(i));
}
}
public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) {
for (int id = 0; id < state.values.size(); id++) {
if (state.values.getKey1(id) == statePosition) {
long value = state.values.getKey2(id);
combine(current, currentGroupId, value);
}
var sorted = state.sortedForOrdinalMerging(current);
if (statePosition > state.maxGroupId) {
return;
}
var start = statePosition > 0 ? sorted.counts[statePosition - 1] : 0;
var end = sorted.counts[statePosition];
for (int i = start; i < end; i++) {
int id = sorted.ids[i];
current.addValue(currentGroupId, state.getValue(id));
}
}
@ -112,6 +118,20 @@ class ValuesLongAggregator {
}
}
/**
* Values are collected in a hash. Iterating over them in order (row by row) to build the output,
* or merging with other state, can be expensive. To optimize this, we build a sorted structure once,
* and then use it to iterate over the values in order.
*
* @param ids positions of the {@link GroupingState#values} to read.
*/
private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable {
@Override
public void close() {
releasable.close();
}
}
/**
* State for a grouped {@code VALUES} aggregation. This implementation
* emphasizes collect-time performance over the performance of rendering
@ -120,10 +140,15 @@ class ValuesLongAggregator {
* collector operation. But at least it's fairly simple.
*/
public static class GroupingState implements GroupingAggregatorState {
private int maxGroupId = -1;
private final BlockFactory blockFactory;
private final LongLongHash values;
private GroupingState(BigArrays bigArrays) {
values = new LongLongHash(1, bigArrays);
private Sorted sortedForOrdinalMerging = null;
private GroupingState(DriverContext driverContext) {
this.blockFactory = driverContext.blockFactory();
values = new LongLongHash(1, driverContext.bigArrays());
}
@Override
@ -131,6 +156,11 @@ class ValuesLongAggregator {
blocks[offset] = toBlock(driverContext.blockFactory(), selected);
}
/**
 * Collects one value for the given group, remembering the largest group id
 * seen so far (used later to size the sorted merging structure).
 */
void addValue(int groupId, long v) {
    values.add(groupId, v);
    if (groupId > maxGroupId) {
        maxGroupId = groupId;
    }
}
/**
* Builds a {@link Block} with the unique values collected for the {@code #selected}
* groups. This is the implementation of the final and intermediate results of the agg.
@ -140,8 +170,15 @@ class ValuesLongAggregator {
return blockFactory.newConstantNullBlock(selected.getPositionCount());
}
try (var sorted = buildSorted(selected)) {
return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids);
}
}
private Sorted buildSorted(IntVector selected) {
long selectedCountsSize = 0;
long idsSize = 0;
Sorted sorted = null;
try {
/*
* Get a count of all groups less than the maximum selected group. Count
@ -216,12 +253,25 @@ class ValuesLongAggregator {
ids[selectedCounts[group]++] = id;
}
}
return buildOutputBlock(blockFactory, selected, selectedCounts, ids);
final long totalMemoryUsed = selectedCountsSize + idsSize;
sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids);
return sorted;
} finally {
blockFactory.adjustBreaker(-selectedCountsSize - idsSize);
if (sorted == null) {
blockFactory.adjustBreaker(-selectedCountsSize - idsSize);
}
}
}
/**
 * Returns the {@link Sorted} structure covering every group collected so far
 * ({@code 0} through {@code maxGroupId}), building it lazily on first use and
 * caching it for subsequent calls; the cached copy is released in {@code close()}.
 * Reusing one sorted structure across all merged positions is what avoids the
 * accidental O(N^2) cost when merging states during ordinals grouping.
 *
 * @param other the destination state of the merge; not read by this variant —
 *              the parameter exists for signature parity with variants that
 *              rehash shared bytes into the destination
 */
private Sorted sortedForOrdinalMerging(GroupingState other) {
    if (sortedForOrdinalMerging == null) {
        // Build once over all groups seen so far.
        try (var selected = IntVector.range(0, maxGroupId + 1, blockFactory)) {
            sortedForOrdinalMerging = buildSorted(selected);
        }
    }
    return sortedForOrdinalMerging;
}
Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) {
/*
* Insert the ids in order.
@ -234,11 +284,11 @@ class ValuesLongAggregator {
int count = end - start;
switch (count) {
case 0 -> builder.appendNull();
case 1 -> append(builder, ids[start]);
case 1 -> builder.appendLong(getValue(ids[start]));
default -> {
builder.beginPositionEntry();
for (int i = start; i < end; i++) {
append(builder, ids[i]);
builder.appendLong(getValue(ids[i]));
}
builder.endPositionEntry();
}
@ -249,9 +299,8 @@ class ValuesLongAggregator {
}
}
private void append(LongBlock.Builder builder, int id) {
long value = values.getKey2(id);
builder.appendLong(value);
long getValue(int valueId) {
return values.getKey2(valueId);
}
@Override
@ -261,7 +310,7 @@ class ValuesLongAggregator {
@Override
public void close() {
values.close();
Releasables.closeExpectNoException(values, sortedForOrdinalMerging);
}
}
}

View File

@ -43,7 +43,7 @@ public final class ValuesBytesRefGroupingAggregatorFunction implements GroupingA
public static ValuesBytesRefGroupingAggregatorFunction create(List<Integer> channels,
DriverContext driverContext) {
return new ValuesBytesRefGroupingAggregatorFunction(channels, ValuesBytesRefAggregator.initGrouping(driverContext.bigArrays()), driverContext);
return new ValuesBytesRefGroupingAggregatorFunction(channels, ValuesBytesRefAggregator.initGrouping(driverContext), driverContext);
}
public static List<IntermediateStateDesc> intermediateStateDesc() {

View File

@ -42,7 +42,7 @@ public final class ValuesDoubleGroupingAggregatorFunction implements GroupingAgg
public static ValuesDoubleGroupingAggregatorFunction create(List<Integer> channels,
DriverContext driverContext) {
return new ValuesDoubleGroupingAggregatorFunction(channels, ValuesDoubleAggregator.initGrouping(driverContext.bigArrays()), driverContext);
return new ValuesDoubleGroupingAggregatorFunction(channels, ValuesDoubleAggregator.initGrouping(driverContext), driverContext);
}
public static List<IntermediateStateDesc> intermediateStateDesc() {

View File

@ -42,7 +42,7 @@ public final class ValuesFloatGroupingAggregatorFunction implements GroupingAggr
public static ValuesFloatGroupingAggregatorFunction create(List<Integer> channels,
DriverContext driverContext) {
return new ValuesFloatGroupingAggregatorFunction(channels, ValuesFloatAggregator.initGrouping(driverContext.bigArrays()), driverContext);
return new ValuesFloatGroupingAggregatorFunction(channels, ValuesFloatAggregator.initGrouping(driverContext), driverContext);
}
public static List<IntermediateStateDesc> intermediateStateDesc() {

View File

@ -41,7 +41,7 @@ public final class ValuesIntGroupingAggregatorFunction implements GroupingAggreg
public static ValuesIntGroupingAggregatorFunction create(List<Integer> channels,
DriverContext driverContext) {
return new ValuesIntGroupingAggregatorFunction(channels, ValuesIntAggregator.initGrouping(driverContext.bigArrays()), driverContext);
return new ValuesIntGroupingAggregatorFunction(channels, ValuesIntAggregator.initGrouping(driverContext), driverContext);
}
public static List<IntermediateStateDesc> intermediateStateDesc() {

View File

@ -42,7 +42,7 @@ public final class ValuesLongGroupingAggregatorFunction implements GroupingAggre
public static ValuesLongGroupingAggregatorFunction create(List<Integer> channels,
DriverContext driverContext) {
return new ValuesLongGroupingAggregatorFunction(channels, ValuesLongAggregator.initGrouping(driverContext.bigArrays()), driverContext);
return new ValuesLongGroupingAggregatorFunction(channels, ValuesLongAggregator.initGrouping(driverContext), driverContext);
}
public static List<IntermediateStateDesc> intermediateStateDesc() {

View File

@ -55,7 +55,7 @@ final class ValuesBytesRefAggregators {
int valuesStart = ordinalIds.getFirstValueIndex(groupPosition + positionOffset);
int valuesEnd = valuesStart + ordinalIds.getValueCount(groupPosition + positionOffset);
for (int v = valuesStart; v < valuesEnd; v++) {
state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(v)));
state.addValueOrdinal(groupId, hashIds.getInt(ordinalIds.getInt(v)));
}
}
}
@ -77,7 +77,7 @@ final class ValuesBytesRefAggregators {
int valuesStart = ordinalIds.getFirstValueIndex(groupPosition + positionOffset);
int valuesEnd = valuesStart + ordinalIds.getValueCount(groupPosition + positionOffset);
for (int v = valuesStart; v < valuesEnd; v++) {
state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(v)));
state.addValueOrdinal(groupId, hashIds.getInt(ordinalIds.getInt(v)));
}
}
}
@ -93,7 +93,7 @@ final class ValuesBytesRefAggregators {
int valuesStart = ordinalIds.getFirstValueIndex(groupPosition + positionOffset);
int valuesEnd = valuesStart + ordinalIds.getValueCount(groupPosition + positionOffset);
for (int v = valuesStart; v < valuesEnd; v++) {
state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(v)));
state.addValueOrdinal(groupId, hashIds.getInt(ordinalIds.getInt(v)));
}
}
}
@ -135,7 +135,7 @@ final class ValuesBytesRefAggregators {
int groupEnd = groupStart + groupIds.getValueCount(groupPosition);
for (int g = groupStart; g < groupEnd; g++) {
int groupId = groupIds.getInt(g);
state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset)));
state.addValueOrdinal(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset)));
}
}
}
@ -150,7 +150,7 @@ final class ValuesBytesRefAggregators {
int groupEnd = groupStart + groupIds.getValueCount(groupPosition);
for (int g = groupStart; g < groupEnd; g++) {
int groupId = groupIds.getInt(g);
state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset)));
state.addValueOrdinal(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset)));
}
}
}
@ -159,7 +159,7 @@ final class ValuesBytesRefAggregators {
public void add(int positionOffset, IntVector groupIds) {
for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) {
int groupId = groupIds.getInt(groupPosition);
state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset)));
state.addValueOrdinal(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset)));
}
}

View File

@ -43,12 +43,9 @@ $if(BytesRef)$
import org.elasticsearch.compute.data.OrdinalBytesRefBlock;
$endif$
import org.elasticsearch.compute.operator.DriverContext;
$if(BytesRef)$
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
$else$
$endif$
/**
* Aggregates field values for $type$.
* This class is generated. Edit @{code X-ValuesAggregator.java.st} instead
@ -90,8 +87,8 @@ $endif$
return state.toBlock(driverContext.blockFactory());
}
public static GroupingState initGrouping(BigArrays bigArrays) {
return new GroupingState(bigArrays);
public static GroupingState initGrouping(DriverContext driverContext) {
return new GroupingState(driverContext);
}
$if(BytesRef)$
@ -113,25 +110,7 @@ $if(BytesRef)$
$endif$
public static void combine(GroupingState state, int groupId, $type$ v) {
$if(long)$
state.values.add(groupId, v);
$elseif(double)$
state.values.add(groupId, Double.doubleToLongBits(v));
$elseif(BytesRef)$
state.values.add(groupId, BlockHash.hashOrdToGroup(state.bytes.add(v)));
$elseif(int)$
/*
* Encode the groupId and value into a single long -
* the top 32 bits for the group, the bottom 32 for the value.
*/
state.values.add((((long) groupId) << Integer.SIZE) | (v & 0xFFFFFFFFL));
$elseif(float)$
/*
* Encode the groupId and value into a single long -
* the top 32 bits for the group, the bottom 32 for the value.
*/
state.values.add((((long) groupId) << Float.SIZE) | (Float.floatToIntBits(v) & 0xFFFFFFFFL));
$endif$
state.addValue(groupId, v);
}
public static void combineIntermediate(GroupingState state, int groupId, $Type$Block values, int valuesPosition) {
@ -142,37 +121,27 @@ $endif$
int end = start + values.getValueCount(valuesPosition);
for (int i = start; i < end; i++) {
$if(BytesRef)$
combine(state, groupId, values.getBytesRef(i, scratch));
state.addValue(groupId, values.getBytesRef(i, scratch));
$else$
combine(state, groupId, values.get$Type$(i));
state.addValue(groupId, values.get$Type$(i));
$endif$
}
}
public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) {
var sorted = state.sortedForOrdinalMerging(current);
if (statePosition > state.maxGroupId) {
return;
}
var start = statePosition > 0 ? sorted.counts[statePosition - 1] : 0;
var end = sorted.counts[statePosition];
for (int i = start; i < end; i++) {
int id = sorted.ids[i];
$if(BytesRef)$
BytesRef scratch = new BytesRef();
current.addValueOrdinal(currentGroupId, id);
$else$
current.addValue(currentGroupId, state.getValue(id));
$endif$
for (int id = 0; id < state.values.size(); id++) {
$if(long||BytesRef)$
if (state.values.getKey1(id) == statePosition) {
long value = state.values.getKey2(id);
$elseif(double)$
if (state.values.getKey1(id) == statePosition) {
double value = Double.longBitsToDouble(state.values.getKey2(id));
$elseif(int)$
long both = state.values.get(id);
int group = (int) (both >>> Integer.SIZE);
if (group == statePosition) {
int value = (int) both;
$elseif(float)$
long both = state.values.get(id);
int group = (int) (both >>> Float.SIZE);
if (group == statePosition) {
float value = Float.intBitsToFloat((int) both);
$endif$
combine(current, currentGroupId, $if(BytesRef)$state.bytes.get(value, scratch)$else$value$endif$);
}
}
}
@ -247,6 +216,24 @@ $endif$
}
}
/**
 * Values are collected in a hash. Iterating over them in order (row by row) to build the output,
 * or merging with other state, can be expensive. To optimize this, we build a sorted structure once,
 * and then use it to iterate over the values in order.
 *
 * @param releasable returns the circuit-breaker memory reserved for {@code counts} and {@code ids}
 * @param counts for the i-th selected group, {@code counts[i]} is the exclusive end of that group's
 *               run in {@code ids}; the run starts at {@code i == 0 ? 0 : counts[i - 1]}
 * @param ids positions of the {@link GroupingState#values} to read.
$if(BytesRef)$
 * If built from {@link GroupingState#sortedForOrdinalMerging(GroupingState)},
 * these are ordinals referring to the {@link GroupingState#bytes} in the target state.
$endif$
 */
private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable {
    @Override
    public void close() {
        releasable.close();
    }
}
/**
* State for a grouped {@code VALUES} aggregation. This implementation
* emphasizes collect-time performance over the performance of rendering
@ -255,26 +242,31 @@ $endif$
* collector operation. But at least it's fairly simple.
*/
public static class GroupingState implements GroupingAggregatorState {
private int maxGroupId = -1;
private final BlockFactory blockFactory;
$if(long||double)$
private final LongLongHash values;
$elseif(BytesRef)$
final LongLongHash values;
private final LongLongHash values;
BytesRefHash bytes;
$elseif(int||float)$
private final LongHash values;
$endif$
private GroupingState(BigArrays bigArrays) {
private Sorted sortedForOrdinalMerging = null;
private GroupingState(DriverContext driverContext) {
this.blockFactory = driverContext.blockFactory();
$if(long||double)$
values = new LongLongHash(1, bigArrays);
values = new LongLongHash(1, driverContext.bigArrays());
$elseif(BytesRef)$
LongLongHash _values = null;
BytesRefHash _bytes = null;
try {
_values = new LongLongHash(1, bigArrays);
_bytes = new BytesRefHash(1, bigArrays);
_values = new LongLongHash(1, driverContext.bigArrays());
_bytes = new BytesRefHash(1, driverContext.bigArrays());
values = _values;
bytes = _bytes;
@ -285,7 +277,7 @@ $elseif(BytesRef)$
Releasables.closeExpectNoException(_values, _bytes);
}
$elseif(int||float)$
values = new LongHash(1, bigArrays);
values = new LongHash(1, driverContext.bigArrays());
$endif$
}
@ -294,6 +286,36 @@ $endif$
blocks[offset] = toBlock(driverContext.blockFactory(), selected);
}
$if(BytesRef)$
/**
 * Collects a value that is already an ordinal in {@link #bytes}, skipping the
 * re-hash of the raw bytes, and remembers the largest group id seen so far.
 */
void addValueOrdinal(int groupId, long valueOrdinal) {
    values.add(groupId, valueOrdinal);
    maxGroupId = Math.max(maxGroupId, groupId);
}

$endif$
/**
 * Collects one value for the given group and remembers the largest group id
 * seen so far (used later to size the sorted merging structure).
 */
void addValue(int groupId, $type$ v) {
$if(long)$
    values.add(groupId, v);
$elseif(double)$
    // Doubles are stored bit-for-bit as longs.
    values.add(groupId, Double.doubleToLongBits(v));
$elseif(BytesRef)$
    // Dedupe the bytes through the hash and store the resulting ordinal.
    values.add(groupId, BlockHash.hashOrdToGroup(bytes.add(v)));
$elseif(int)$
    /*
     * Encode the groupId and value into a single long -
     * the top 32 bits for the group, the bottom 32 for the value.
     */
    values.add((((long) groupId) << Integer.SIZE) | (v & 0xFFFFFFFFL));
$elseif(float)$
    /*
     * Encode the groupId and value into a single long -
     * the top 32 bits for the group, the bottom 32 for the value.
     */
    values.add((((long) groupId) << Float.SIZE) | (Float.floatToIntBits(v) & 0xFFFFFFFFL));
$endif$
    maxGroupId = Math.max(maxGroupId, groupId);
}
/**
* Builds a {@link Block} with the unique values collected for the {@code #selected}
* groups. This is the implementation of the final and intermediate results of the agg.
@ -303,8 +325,23 @@ $endif$
return blockFactory.newConstantNullBlock(selected.getPositionCount());
}
try (var sorted = buildSorted(selected)) {
$if(BytesRef)$
if (OrdinalBytesRefBlock.isDense(selected.getPositionCount(), Math.toIntExact(values.size()))) {
return buildOrdinalOutputBlock(blockFactory, selected, sorted.counts, sorted.ids);
} else {
return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids);
}
$else$
return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids);
$endif$
}
}
private Sorted buildSorted(IntVector selected) {
long selectedCountsSize = 0;
long idsSize = 0;
Sorted sorted = null;
try {
/*
* Get a count of all groups less than the maximum selected group. Count
@ -379,30 +416,56 @@ $endif$
idsSize = adjust;
int[] ids = new int[total];
for (int id = 0; id < values.size(); id++) {
$if(long||BytesRef||double)$
$if(long||BytesRef||double)$
int group = (int) values.getKey1(id);
$elseif(float||int)$
$elseif(float||int)$
long both = values.get(id);
int group = (int) (both >>> Float.SIZE);
$endif$
$endif$
if (group < selectedCounts.length && selectedCounts[group] >= 0) {
ids[selectedCounts[group]++] = id;
}
}
$if(BytesRef)$
if (OrdinalBytesRefBlock.isDense(selected.getPositionCount(), Math.toIntExact(values.size()))) {
return buildOrdinalOutputBlock(blockFactory, selected, selectedCounts, ids);
} else {
return buildOutputBlock(blockFactory, selected, selectedCounts, ids);
}
$else$
return buildOutputBlock(blockFactory, selected, selectedCounts, ids);
$endif$
final long totalMemoryUsed = selectedCountsSize + idsSize;
sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids);
return sorted;
} finally {
blockFactory.adjustBreaker(-selectedCountsSize - idsSize);
if (sorted == null) {
blockFactory.adjustBreaker(-selectedCountsSize - idsSize);
}
}
}
/**
 * Returns the {@link Sorted} structure covering every group collected so far
 * ({@code 0} through {@code maxGroupId}), building it lazily on first use and
 * caching it for subsequent calls; the cached copy is released in {@code close()}.
 * Reusing one sorted structure across all merged positions is what avoids the
 * accidental O(N^2) cost when merging states during ordinals grouping. In the
 * BytesRef variant the cached ids are additionally remapped to ordinals of
 * {@code other.bytes}, so each distinct value is hashed into the destination
 * state only once.
 */
private Sorted sortedForOrdinalMerging(GroupingState other) {
    if (sortedForOrdinalMerging == null) {
        try (var selected = IntVector.range(0, maxGroupId + 1, blockFactory)) {
            sortedForOrdinalMerging = buildSorted(selected);
$if(BytesRef)$
            // hash all the bytes to the destination to avoid hashing them multiple times
            BytesRef scratch = new BytesRef();
            final int totalValue = Math.toIntExact(bytes.size());
            // Reserve breaker memory for the temporary ordinal-mapping table.
            blockFactory.adjustBreaker((long) totalValue * Integer.BYTES);
            try {
                final int[] mappedIds = new int[totalValue];
                for (int i = 0; i < totalValue; i++) {
                    var v = bytes.get(i, scratch);
                    mappedIds[i] = Math.toIntExact(BlockHash.hashOrdToGroup(other.bytes.add(v)));
                }
                // no longer need the bytes
                // NOTE(review): bytes becomes null after the first merge, so this state
                // must not collect further raw values afterwards — confirm merge-only use.
                bytes.close();
                bytes = null;
                // Rewrite each cached id from a position in this state's values hash
                // to the destination ordinal of the value it refers to.
                for (int i = 0; i < sortedForOrdinalMerging.ids.length; i++) {
                    sortedForOrdinalMerging.ids[i] = mappedIds[Math.toIntExact(values.getKey2(sortedForOrdinalMerging.ids[i]))];
                }
            } finally {
                blockFactory.adjustBreaker(-(long) totalValue * Integer.BYTES);
            }
$endif$
        }
    }
    return sortedForOrdinalMerging;
}
Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) {
/*
* Insert the ids in order.
@ -418,11 +481,11 @@ $endif$
int count = end - start;
switch (count) {
case 0 -> builder.appendNull();
case 1 -> append(builder, ids[start]$if(BytesRef)$, scratch$endif$);
case 1 -> builder.append$Type$(getValue(ids[start]$if(BytesRef)$, scratch$endif$));
default -> {
builder.beginPositionEntry();
for (int i = start; i < end; i++) {
append(builder, ids[i]$if(BytesRef)$, scratch$endif$);
builder.append$Type$(getValue(ids[i]$if(BytesRef)$, scratch$endif$));
}
builder.endPositionEntry();
}
@ -470,29 +533,24 @@ $if(BytesRef)$
}
}
}
$endif$
private void append($Type$Block.Builder builder, int id, BytesRef scratch) {
BytesRef value = bytes.get(values.getKey2(id), scratch);
builder.appendBytesRef(value);
}
$else$
private void append($Type$Block.Builder builder, int id) {
$if(long)$
long value = values.getKey2(id);
$type$ getValue(int valueId$if(BytesRef)$, BytesRef scratch$endif$) {
$if(BytesRef)$
return bytes.get(values.getKey2(valueId), scratch);
$elseif(long)$
return values.getKey2(valueId);
$elseif(double)$
double value = Double.longBitsToDouble(values.getKey2(id));
return Double.longBitsToDouble(values.getKey2(valueId));
$elseif(float)$
long both = values.get(id);
float value = Float.intBitsToFloat((int) both);
long both = values.get(valueId);
return Float.intBitsToFloat((int) both);
$elseif(int)$
long both = values.get(id);
int value = (int) both;
long both = values.get(valueId);
return (int) both;
$endif$
builder.append$Type$(value);
}
$endif$
@Override
public void enableGroupIdTracking(SeenGroupIds seen) {
// we figure out seen values from nulls on the values block
@ -501,9 +559,9 @@ $endif$
@Override
public void close() {
$if(BytesRef)$
Releasables.closeExpectNoException(values, bytes);
Releasables.closeExpectNoException(values, bytes, sortedForOrdinalMerging);
$else$
values.close();
Releasables.closeExpectNoException(values, sortedForOrdinalMerging);
$endif$
}
}

View File

@ -85,7 +85,7 @@ public class HashAggregationOperator implements Operator {
private final BlockHash blockHash;
private final List<GroupingAggregator> aggregators;
protected final List<GroupingAggregator> aggregators;
protected final DriverContext driverContext;