Avoid O(N^2) in VALUES with ordinals grouping (#130576)

Using the VALUES aggregator with ordinals grouping led to accidental
quadratic complexity. Queries like FROM .. | STATS ... VALUES(field) ...
BY keyword-field are affected by this performance issue. This change
builds the sorted structure - previously introduced to fix a similar O(N^2)
problem when emitting the output block - once and caches it for reuse
during the merging phase of the OrdinalsGroupingOperator.
This commit is contained in:
Nhat Nguyen 2025-07-07 14:51:29 -07:00 committed by GitHub
parent 02b2f5eb66
commit 59df1bfd51
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 570 additions and 221 deletions

View File

@ -95,7 +95,8 @@ public class ValuesAggregatorBenchmark {
try { try {
for (String groups : ValuesAggregatorBenchmark.class.getField("groups").getAnnotationsByType(Param.class)[0].value()) { for (String groups : ValuesAggregatorBenchmark.class.getField("groups").getAnnotationsByType(Param.class)[0].value()) {
for (String dataType : ValuesAggregatorBenchmark.class.getField("dataType").getAnnotationsByType(Param.class)[0].value()) { for (String dataType : ValuesAggregatorBenchmark.class.getField("dataType").getAnnotationsByType(Param.class)[0].value()) {
run(Integer.parseInt(groups), dataType, 10); run(Integer.parseInt(groups), dataType, 10, 0);
run(Integer.parseInt(groups), dataType, 10, 1);
} }
} }
} catch (NoSuchFieldException e) { } catch (NoSuchFieldException e) {
@ -113,7 +114,10 @@ public class ValuesAggregatorBenchmark {
@Param({ BYTES_REF, INT, LONG }) @Param({ BYTES_REF, INT, LONG })
public String dataType; public String dataType;
private static Operator operator(DriverContext driverContext, int groups, String dataType) { @Param({ "0", "1" })
public int numOrdinalMerges;
private static Operator operator(DriverContext driverContext, int groups, String dataType, int numOrdinalMerges) {
if (groups == 1) { if (groups == 1) {
return new AggregationOperator( return new AggregationOperator(
List.of(supplier(dataType).aggregatorFactory(AggregatorMode.SINGLE, List.of(0)).apply(driverContext)), List.of(supplier(dataType).aggregatorFactory(AggregatorMode.SINGLE, List.of(0)).apply(driverContext)),
@ -125,7 +129,24 @@ public class ValuesAggregatorBenchmark {
List.of(supplier(dataType).groupingAggregatorFactory(AggregatorMode.SINGLE, List.of(1))), List.of(supplier(dataType).groupingAggregatorFactory(AggregatorMode.SINGLE, List.of(1))),
() -> BlockHash.build(groupSpec, driverContext.blockFactory(), 16 * 1024, false), () -> BlockHash.build(groupSpec, driverContext.blockFactory(), 16 * 1024, false),
driverContext driverContext
); ) {
// Runs the simulated ordinal merge first, then delegates to the wrapped operator's own
// getOutput(), so the merge is executed on every call that asks this operator for output.
@Override
public Page getOutput() {
mergeOrdinal();
return super.getOutput();
}
// simulate OrdinalsGroupingOperator: create a fresh grouping aggregator, feed it the state of the
// current first aggregator one group position at a time (positions 0..groups-1), repeat that
// numOrdinalMerges times, and finally install the fresh aggregator in slot 0.
// NOTE(review): the aggregator previously in slot 0 is replaced without being closed in this method,
// and when numOrdinalMerges == 0 the loop never runs, so the installed aggregator has received no
// rows - confirm both are intended for the benchmark.
void mergeOrdinal() {
var merged = supplier(dataType).groupingAggregatorFactory(AggregatorMode.SINGLE, List.of(1)).apply(driverContext);
for (int i = 0; i < numOrdinalMerges; i++) {
for (int p = 0; p < groups; p++) {
// source and destination group ids are the same position p
merged.addIntermediateRow(p, aggregators.getFirst(), p);
}
}
aggregators.set(0, merged);
}
};
} }
private static AggregatorFunctionSupplier supplier(String dataType) { private static AggregatorFunctionSupplier supplier(String dataType) {
@ -331,12 +352,12 @@ public class ValuesAggregatorBenchmark {
@Benchmark @Benchmark
public void run() { public void run() {
run(groups, dataType, OP_COUNT); run(groups, dataType, OP_COUNT, numOrdinalMerges);
} }
private static void run(int groups, String dataType, int opCount) { private static void run(int groups, String dataType, int opCount, int numOrdinalMerges) {
DriverContext driverContext = driverContext(); DriverContext driverContext = driverContext();
try (Operator operator = operator(driverContext, groups, dataType)) { try (Operator operator = operator(driverContext, groups, dataType, numOrdinalMerges)) {
Page page = page(groups, dataType); Page page = page(groups, dataType);
for (int i = 0; i < opCount; i++) { for (int i = 0; i < opCount; i++) {
operator.addInput(page.shallowCopy()); operator.addInput(page.shallowCopy());

View File

@ -0,0 +1,5 @@
pr: 130576
summary: Avoid O(N^2) in VALUES with ordinals grouping
area: ES|QL
type: bug
issues: []

View File

@ -24,6 +24,7 @@ import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.OrdinalBytesRefBlock; import org.elasticsearch.compute.data.OrdinalBytesRefBlock;
import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables; import org.elasticsearch.core.Releasables;
/** /**
@ -55,8 +56,8 @@ class ValuesBytesRefAggregator {
return state.toBlock(driverContext.blockFactory()); return state.toBlock(driverContext.blockFactory());
} }
public static GroupingState initGrouping(BigArrays bigArrays) { public static GroupingState initGrouping(DriverContext driverContext) {
return new GroupingState(bigArrays); return new GroupingState(driverContext);
} }
public static GroupingAggregatorFunction.AddInput wrapAddInput( public static GroupingAggregatorFunction.AddInput wrapAddInput(
@ -76,7 +77,7 @@ class ValuesBytesRefAggregator {
} }
public static void combine(GroupingState state, int groupId, BytesRef v) { public static void combine(GroupingState state, int groupId, BytesRef v) {
state.values.add(groupId, BlockHash.hashOrdToGroup(state.bytes.add(v))); state.addValue(groupId, v);
} }
public static void combineIntermediate(GroupingState state, int groupId, BytesRefBlock values, int valuesPosition) { public static void combineIntermediate(GroupingState state, int groupId, BytesRefBlock values, int valuesPosition) {
@ -84,17 +85,20 @@ class ValuesBytesRefAggregator {
int start = values.getFirstValueIndex(valuesPosition); int start = values.getFirstValueIndex(valuesPosition);
int end = start + values.getValueCount(valuesPosition); int end = start + values.getValueCount(valuesPosition);
for (int i = start; i < end; i++) { for (int i = start; i < end; i++) {
combine(state, groupId, values.getBytesRef(i, scratch)); state.addValue(groupId, values.getBytesRef(i, scratch));
} }
} }
public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) { public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) {
BytesRef scratch = new BytesRef(); var sorted = state.sortedForOrdinalMerging(current);
for (int id = 0; id < state.values.size(); id++) { if (statePosition > state.maxGroupId) {
if (state.values.getKey1(id) == statePosition) { return;
long value = state.values.getKey2(id); }
combine(current, currentGroupId, state.bytes.get(value, scratch)); var start = statePosition > 0 ? sorted.counts[statePosition - 1] : 0;
} var end = sorted.counts[statePosition];
for (int i = start; i < end; i++) {
int id = sorted.ids[i];
current.addValueOrdinal(currentGroupId, id);
} }
} }
@ -138,6 +142,22 @@ class ValuesBytesRefAggregator {
} }
} }
/**
 * Values are collected in a hash. Iterating over them in order (row by row) to build the output,
 * or merging with other state, can be expensive. To optimize this, we build a sorted structure once,
 * and then use it to iterate over the values in order.
 * <p>
 * The entries of one group form a contiguous run of {@code ids}. When merging states, the run for
 * group {@code g} is read from {@code counts[g - 1]} (or {@code 0} for group {@code 0}) up to, but
 * excluding, {@code counts[g]}.
 *
 * @param releasable closed by {@link #close()}; {@code buildSorted} supplies a callback that returns
 *                   the memory reserved for the two arrays to the circuit breaker
 * @param counts per-group end offsets into {@code ids}
 * @param ids positions of the {@link GroupingState#values} to read.
 *            If built from {@link GroupingState#sortedForOrdinalMerging(GroupingState)},
 *            these are ordinals referring to the {@link GroupingState#bytes} in the target state.
 */
private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable {
@Override
public void close() {
releasable.close();
}
}
/** /**
* State for a grouped {@code VALUES} aggregation. This implementation * State for a grouped {@code VALUES} aggregation. This implementation
* emphasizes collect-time performance over the performance of rendering * emphasizes collect-time performance over the performance of rendering
@ -146,15 +166,20 @@ class ValuesBytesRefAggregator {
* collector operation. But at least it's fairly simple. * collector operation. But at least it's fairly simple.
*/ */
public static class GroupingState implements GroupingAggregatorState { public static class GroupingState implements GroupingAggregatorState {
final LongLongHash values; private int maxGroupId = -1;
private final BlockFactory blockFactory;
private final LongLongHash values;
BytesRefHash bytes; BytesRefHash bytes;
private GroupingState(BigArrays bigArrays) { private Sorted sortedForOrdinalMerging = null;
private GroupingState(DriverContext driverContext) {
this.blockFactory = driverContext.blockFactory();
LongLongHash _values = null; LongLongHash _values = null;
BytesRefHash _bytes = null; BytesRefHash _bytes = null;
try { try {
_values = new LongLongHash(1, bigArrays); _values = new LongLongHash(1, driverContext.bigArrays());
_bytes = new BytesRefHash(1, bigArrays); _bytes = new BytesRefHash(1, driverContext.bigArrays());
values = _values; values = _values;
bytes = _bytes; bytes = _bytes;
@ -171,6 +196,16 @@ class ValuesBytesRefAggregator {
blocks[offset] = toBlock(driverContext.blockFactory(), selected); blocks[offset] = toBlock(driverContext.blockFactory(), selected);
} }
/**
 * Records the pair ({@code groupId}, {@code valueOrdinal}) in {@code values} and widens
 * {@code maxGroupId} when {@code groupId} is the largest group seen so far.
 *
 * @param groupId group the value belongs to
 * @param valueOrdinal identifier of the value, stored as the second key of {@code values}
 */
void addValueOrdinal(int groupId, long valueOrdinal) {
    values.add(groupId, valueOrdinal);
    if (groupId > maxGroupId) {
        maxGroupId = groupId;
    }
}
/**
 * Adds {@code v} to the {@code bytes} hash, then records the resulting ordinal for {@code groupId}.
 *
 * @param groupId group the value belongs to
 * @param v value to collect
 */
void addValue(int groupId, BytesRef v) {
    // bytes.add reports an already-present key in an encoded form; hashOrdToGroup decodes it
    final long valueOrdinal = BlockHash.hashOrdToGroup(bytes.add(v));
    addValueOrdinal(groupId, valueOrdinal);
}
/** /**
* Builds a {@link Block} with the unique values collected for the {@code #selected} * Builds a {@link Block} with the unique values collected for the {@code #selected}
* groups. This is the implementation of the final and intermediate results of the agg. * groups. This is the implementation of the final and intermediate results of the agg.
@ -180,8 +215,19 @@ class ValuesBytesRefAggregator {
return blockFactory.newConstantNullBlock(selected.getPositionCount()); return blockFactory.newConstantNullBlock(selected.getPositionCount());
} }
try (var sorted = buildSorted(selected)) {
if (OrdinalBytesRefBlock.isDense(selected.getPositionCount(), Math.toIntExact(values.size()))) {
return buildOrdinalOutputBlock(blockFactory, selected, sorted.counts, sorted.ids);
} else {
return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids);
}
}
}
private Sorted buildSorted(IntVector selected) {
long selectedCountsSize = 0; long selectedCountsSize = 0;
long idsSize = 0; long idsSize = 0;
Sorted sorted = null;
try { try {
/* /*
* Get a count of all groups less than the maximum selected group. Count * Get a count of all groups less than the maximum selected group. Count
@ -256,16 +302,44 @@ class ValuesBytesRefAggregator {
ids[selectedCounts[group]++] = id; ids[selectedCounts[group]++] = id;
} }
} }
if (OrdinalBytesRefBlock.isDense(selected.getPositionCount(), Math.toIntExact(values.size()))) { final long totalMemoryUsed = selectedCountsSize + idsSize;
return buildOrdinalOutputBlock(blockFactory, selected, selectedCounts, ids); sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids);
} else { return sorted;
return buildOutputBlock(blockFactory, selected, selectedCounts, ids);
}
} finally { } finally {
blockFactory.adjustBreaker(-selectedCountsSize - idsSize); if (sorted == null) {
blockFactory.adjustBreaker(-selectedCountsSize - idsSize);
}
} }
} }
/**
 * Returns the per-group sorted view used to merge this state into {@code other}, building and
 * caching it on the first call.
 * <p>
 * Building hashes every value of this state into {@code other.bytes} exactly once and rewrites the
 * view's {@code ids} to the resulting ordinals, so the returned {@code ids} refer to {@code other}'s
 * {@code bytes} rather than to positions in {@code values}. As a side effect this state's own
 * {@code bytes} is closed and set to {@code null}.
 * <p>
 * The view is stored in the cache field only after the translation has finished. If building fails
 * part way, the partially built view is released and nothing is cached, so a half-translated view
 * can never be returned by a later call.
 * NOTE(review): the cached view is tied to the {@code other} of the first call; later calls return it
 * without checking the argument - confirm all callers merge into a single target state.
 *
 * @param other the state that the values of this state are being merged into
 * @return the cached view, with {@code ids} expressed as ordinals of {@code other.bytes}
 */
private Sorted sortedForOrdinalMerging(GroupingState other) {
    if (sortedForOrdinalMerging != null) {
        return sortedForOrdinalMerging;
    }
    Sorted sorted = null;
    try (var selected = IntVector.range(0, maxGroupId + 1, blockFactory)) {
        sorted = buildSorted(selected);
        // hash all the bytes to the destination to avoid hashing them multiple times
        BytesRef scratch = new BytesRef();
        final int totalValue = Math.toIntExact(bytes.size());
        final long mappedIdsSize = (long) totalValue * Integer.BYTES;
        // account for the temporary mapping table before allocating it
        blockFactory.adjustBreaker(mappedIdsSize);
        try {
            final int[] mappedIds = new int[totalValue];
            for (int i = 0; i < totalValue; i++) {
                var v = bytes.get(i, scratch);
                mappedIds[i] = Math.toIntExact(BlockHash.hashOrdToGroup(other.bytes.add(v)));
            }
            // no longer need the bytes
            bytes.close();
            bytes = null;
            // turn "position in values" into "ordinal in other.bytes"
            for (int i = 0; i < sorted.ids.length; i++) {
                sorted.ids[i] = mappedIds[Math.toIntExact(values.getKey2(sorted.ids[i]))];
            }
        } finally {
            blockFactory.adjustBreaker(-mappedIdsSize);
        }
        // publish only once the view is fully translated
        sortedForOrdinalMerging = sorted;
    } finally {
        if (sortedForOrdinalMerging == null && sorted != null) {
            // building failed after the view was allocated: give its memory back
            sorted.close();
        }
    }
    return sortedForOrdinalMerging;
}
Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) {
/* /*
* Insert the ids in order. * Insert the ids in order.
@ -279,11 +353,11 @@ class ValuesBytesRefAggregator {
int count = end - start; int count = end - start;
switch (count) { switch (count) {
case 0 -> builder.appendNull(); case 0 -> builder.appendNull();
case 1 -> append(builder, ids[start], scratch); case 1 -> builder.appendBytesRef(getValue(ids[start], scratch));
default -> { default -> {
builder.beginPositionEntry(); builder.beginPositionEntry();
for (int i = start; i < end; i++) { for (int i = start; i < end; i++) {
append(builder, ids[i], scratch); builder.appendBytesRef(getValue(ids[i], scratch));
} }
builder.endPositionEntry(); builder.endPositionEntry();
} }
@ -331,9 +405,8 @@ class ValuesBytesRefAggregator {
} }
} }
private void append(BytesRefBlock.Builder builder, int id, BytesRef scratch) { BytesRef getValue(int valueId, BytesRef scratch) {
BytesRef value = bytes.get(values.getKey2(id), scratch); return bytes.get(values.getKey2(valueId), scratch);
builder.appendBytesRef(value);
} }
@Override @Override
@ -343,7 +416,7 @@ class ValuesBytesRefAggregator {
@Override @Override
public void close() { public void close() {
Releasables.closeExpectNoException(values, bytes); Releasables.closeExpectNoException(values, bytes, sortedForOrdinalMerging);
} }
} }
} }

View File

@ -19,6 +19,8 @@ import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
/** /**
* Aggregates field values for double. * Aggregates field values for double.
@ -48,28 +50,32 @@ class ValuesDoubleAggregator {
return state.toBlock(driverContext.blockFactory()); return state.toBlock(driverContext.blockFactory());
} }
public static GroupingState initGrouping(BigArrays bigArrays) { public static GroupingState initGrouping(DriverContext driverContext) {
return new GroupingState(bigArrays); return new GroupingState(driverContext);
} }
public static void combine(GroupingState state, int groupId, double v) { public static void combine(GroupingState state, int groupId, double v) {
state.values.add(groupId, Double.doubleToLongBits(v)); state.addValue(groupId, v);
} }
public static void combineIntermediate(GroupingState state, int groupId, DoubleBlock values, int valuesPosition) { public static void combineIntermediate(GroupingState state, int groupId, DoubleBlock values, int valuesPosition) {
int start = values.getFirstValueIndex(valuesPosition); int start = values.getFirstValueIndex(valuesPosition);
int end = start + values.getValueCount(valuesPosition); int end = start + values.getValueCount(valuesPosition);
for (int i = start; i < end; i++) { for (int i = start; i < end; i++) {
combine(state, groupId, values.getDouble(i)); state.addValue(groupId, values.getDouble(i));
} }
} }
public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) { public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) {
for (int id = 0; id < state.values.size(); id++) { var sorted = state.sortedForOrdinalMerging(current);
if (state.values.getKey1(id) == statePosition) { if (statePosition > state.maxGroupId) {
double value = Double.longBitsToDouble(state.values.getKey2(id)); return;
combine(current, currentGroupId, value); }
} var start = statePosition > 0 ? sorted.counts[statePosition - 1] : 0;
var end = sorted.counts[statePosition];
for (int i = start; i < end; i++) {
int id = sorted.ids[i];
current.addValue(currentGroupId, state.getValue(id));
} }
} }
@ -112,6 +118,20 @@ class ValuesDoubleAggregator {
} }
} }
/**
 * Values are collected in a hash. Iterating over them in order (row by row) to build the output,
 * or merging with other state, can be expensive. To optimize this, we build a sorted structure once,
 * and then use it to iterate over the values in order.
 * <p>
 * The entries of one group form a contiguous run of {@code ids}. When merging states, the run for
 * group {@code g} is read from {@code counts[g - 1]} (or {@code 0} for group {@code 0}) up to, but
 * excluding, {@code counts[g]}.
 *
 * @param releasable closed by {@link #close()}; {@code buildSorted} supplies a callback that returns
 *                   the memory reserved for the two arrays to the circuit breaker
 * @param counts per-group end offsets into {@code ids}
 * @param ids positions of the {@link GroupingState#values} to read.
 */
private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable {
@Override
public void close() {
releasable.close();
}
}
/** /**
* State for a grouped {@code VALUES} aggregation. This implementation * State for a grouped {@code VALUES} aggregation. This implementation
* emphasizes collect-time performance over the performance of rendering * emphasizes collect-time performance over the performance of rendering
@ -120,10 +140,15 @@ class ValuesDoubleAggregator {
* collector operation. But at least it's fairly simple. * collector operation. But at least it's fairly simple.
*/ */
public static class GroupingState implements GroupingAggregatorState { public static class GroupingState implements GroupingAggregatorState {
private int maxGroupId = -1;
private final BlockFactory blockFactory;
private final LongLongHash values; private final LongLongHash values;
private GroupingState(BigArrays bigArrays) { private Sorted sortedForOrdinalMerging = null;
values = new LongLongHash(1, bigArrays);
private GroupingState(DriverContext driverContext) {
this.blockFactory = driverContext.blockFactory();
values = new LongLongHash(1, driverContext.bigArrays());
} }
@Override @Override
@ -131,6 +156,11 @@ class ValuesDoubleAggregator {
blocks[offset] = toBlock(driverContext.blockFactory(), selected); blocks[offset] = toBlock(driverContext.blockFactory(), selected);
} }
/**
 * Records {@code v} for {@code groupId}, storing the value as its raw long bits, and widens
 * {@code maxGroupId} when {@code groupId} is the largest group seen so far.
 *
 * @param groupId group the value belongs to
 * @param v value to collect
 */
void addValue(int groupId, double v) {
    final long valueBits = Double.doubleToLongBits(v);
    values.add(groupId, valueBits);
    if (groupId > maxGroupId) {
        maxGroupId = groupId;
    }
}
/** /**
* Builds a {@link Block} with the unique values collected for the {@code #selected} * Builds a {@link Block} with the unique values collected for the {@code #selected}
* groups. This is the implementation of the final and intermediate results of the agg. * groups. This is the implementation of the final and intermediate results of the agg.
@ -140,8 +170,15 @@ class ValuesDoubleAggregator {
return blockFactory.newConstantNullBlock(selected.getPositionCount()); return blockFactory.newConstantNullBlock(selected.getPositionCount());
} }
try (var sorted = buildSorted(selected)) {
return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids);
}
}
private Sorted buildSorted(IntVector selected) {
long selectedCountsSize = 0; long selectedCountsSize = 0;
long idsSize = 0; long idsSize = 0;
Sorted sorted = null;
try { try {
/* /*
* Get a count of all groups less than the maximum selected group. Count * Get a count of all groups less than the maximum selected group. Count
@ -216,12 +253,25 @@ class ValuesDoubleAggregator {
ids[selectedCounts[group]++] = id; ids[selectedCounts[group]++] = id;
} }
} }
return buildOutputBlock(blockFactory, selected, selectedCounts, ids); final long totalMemoryUsed = selectedCountsSize + idsSize;
sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids);
return sorted;
} finally { } finally {
blockFactory.adjustBreaker(-selectedCountsSize - idsSize); if (sorted == null) {
blockFactory.adjustBreaker(-selectedCountsSize - idsSize);
}
} }
} }
/**
 * Returns the sorted view covering groups {@code 0..maxGroupId}, building it on the first call and
 * handing back the cached instance afterwards.
 *
 * @param other the state being merged into; not read by this implementation
 * @return the cached sorted view of this state
 */
private Sorted sortedForOrdinalMerging(GroupingState other) {
    if (sortedForOrdinalMerging != null) {
        return sortedForOrdinalMerging;
    }
    try (var allGroups = IntVector.range(0, maxGroupId + 1, blockFactory)) {
        sortedForOrdinalMerging = buildSorted(allGroups);
    }
    return sortedForOrdinalMerging;
}
Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) {
/* /*
* Insert the ids in order. * Insert the ids in order.
@ -234,11 +284,11 @@ class ValuesDoubleAggregator {
int count = end - start; int count = end - start;
switch (count) { switch (count) {
case 0 -> builder.appendNull(); case 0 -> builder.appendNull();
case 1 -> append(builder, ids[start]); case 1 -> builder.appendDouble(getValue(ids[start]));
default -> { default -> {
builder.beginPositionEntry(); builder.beginPositionEntry();
for (int i = start; i < end; i++) { for (int i = start; i < end; i++) {
append(builder, ids[i]); builder.appendDouble(getValue(ids[i]));
} }
builder.endPositionEntry(); builder.endPositionEntry();
} }
@ -249,9 +299,8 @@ class ValuesDoubleAggregator {
} }
} }
private void append(DoubleBlock.Builder builder, int id) { double getValue(int valueId) {
double value = Double.longBitsToDouble(values.getKey2(id)); return Double.longBitsToDouble(values.getKey2(valueId));
builder.appendDouble(value);
} }
@Override @Override
@ -261,7 +310,7 @@ class ValuesDoubleAggregator {
@Override @Override
public void close() { public void close() {
values.close(); Releasables.closeExpectNoException(values, sortedForOrdinalMerging);
} }
} }
} }

View File

@ -18,6 +18,8 @@ import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.FloatBlock; import org.elasticsearch.compute.data.FloatBlock;
import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
/** /**
* Aggregates field values for float. * Aggregates field values for float.
@ -47,34 +49,32 @@ class ValuesFloatAggregator {
return state.toBlock(driverContext.blockFactory()); return state.toBlock(driverContext.blockFactory());
} }
public static GroupingState initGrouping(BigArrays bigArrays) { public static GroupingState initGrouping(DriverContext driverContext) {
return new GroupingState(bigArrays); return new GroupingState(driverContext);
} }
public static void combine(GroupingState state, int groupId, float v) { public static void combine(GroupingState state, int groupId, float v) {
/* state.addValue(groupId, v);
* Encode the groupId and value into a single long -
* the top 32 bits for the group, the bottom 32 for the value.
*/
state.values.add((((long) groupId) << Float.SIZE) | (Float.floatToIntBits(v) & 0xFFFFFFFFL));
} }
public static void combineIntermediate(GroupingState state, int groupId, FloatBlock values, int valuesPosition) { public static void combineIntermediate(GroupingState state, int groupId, FloatBlock values, int valuesPosition) {
int start = values.getFirstValueIndex(valuesPosition); int start = values.getFirstValueIndex(valuesPosition);
int end = start + values.getValueCount(valuesPosition); int end = start + values.getValueCount(valuesPosition);
for (int i = start; i < end; i++) { for (int i = start; i < end; i++) {
combine(state, groupId, values.getFloat(i)); state.addValue(groupId, values.getFloat(i));
} }
} }
public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) { public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) {
for (int id = 0; id < state.values.size(); id++) { var sorted = state.sortedForOrdinalMerging(current);
long both = state.values.get(id); if (statePosition > state.maxGroupId) {
int group = (int) (both >>> Float.SIZE); return;
if (group == statePosition) { }
float value = Float.intBitsToFloat((int) both); var start = statePosition > 0 ? sorted.counts[statePosition - 1] : 0;
combine(current, currentGroupId, value); var end = sorted.counts[statePosition];
} for (int i = start; i < end; i++) {
int id = sorted.ids[i];
current.addValue(currentGroupId, state.getValue(id));
} }
} }
@ -117,6 +117,20 @@ class ValuesFloatAggregator {
} }
} }
/**
 * Values are collected in a hash. Iterating over them in order (row by row) to build the output,
 * or merging with other state, can be expensive. To optimize this, we build a sorted structure once,
 * and then use it to iterate over the values in order.
 * <p>
 * The entries of one group form a contiguous run of {@code ids}. When merging states, the run for
 * group {@code g} is read from {@code counts[g - 1]} (or {@code 0} for group {@code 0}) up to, but
 * excluding, {@code counts[g]}.
 *
 * @param releasable closed by {@link #close()}; {@code buildSorted} supplies a callback that returns
 *                   the memory reserved for the two arrays to the circuit breaker
 * @param counts per-group end offsets into {@code ids}
 * @param ids positions of the {@link GroupingState#values} to read.
 */
private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable {
@Override
public void close() {
releasable.close();
}
}
/** /**
* State for a grouped {@code VALUES} aggregation. This implementation * State for a grouped {@code VALUES} aggregation. This implementation
* emphasizes collect-time performance over the performance of rendering * emphasizes collect-time performance over the performance of rendering
@ -125,10 +139,15 @@ class ValuesFloatAggregator {
* collector operation. But at least it's fairly simple. * collector operation. But at least it's fairly simple.
*/ */
public static class GroupingState implements GroupingAggregatorState { public static class GroupingState implements GroupingAggregatorState {
private int maxGroupId = -1;
private final BlockFactory blockFactory;
private final LongHash values; private final LongHash values;
private GroupingState(BigArrays bigArrays) { private Sorted sortedForOrdinalMerging = null;
values = new LongHash(1, bigArrays);
private GroupingState(DriverContext driverContext) {
this.blockFactory = driverContext.blockFactory();
values = new LongHash(1, driverContext.bigArrays());
} }
@Override @Override
@ -136,6 +155,15 @@ class ValuesFloatAggregator {
blocks[offset] = toBlock(driverContext.blockFactory(), selected); blocks[offset] = toBlock(driverContext.blockFactory(), selected);
} }
/**
 * Records {@code v} for {@code groupId} as a single packed long and widens {@code maxGroupId}
 * when {@code groupId} is the largest group seen so far.
 *
 * @param groupId group the value belongs to
 * @param v value to collect
 */
void addValue(int groupId, float v) {
    // Packed key layout: upper 32 bits carry the group, lower 32 bits carry the raw float bits.
    final long groupBits = ((long) groupId) << Float.SIZE;
    final long valueBits = Float.floatToIntBits(v) & 0xFFFFFFFFL;
    values.add(groupBits | valueBits);
    if (groupId > maxGroupId) {
        maxGroupId = groupId;
    }
}
/** /**
* Builds a {@link Block} with the unique values collected for the {@code #selected} * Builds a {@link Block} with the unique values collected for the {@code #selected}
* groups. This is the implementation of the final and intermediate results of the agg. * groups. This is the implementation of the final and intermediate results of the agg.
@ -145,8 +173,15 @@ class ValuesFloatAggregator {
return blockFactory.newConstantNullBlock(selected.getPositionCount()); return blockFactory.newConstantNullBlock(selected.getPositionCount());
} }
try (var sorted = buildSorted(selected)) {
return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids);
}
}
private Sorted buildSorted(IntVector selected) {
long selectedCountsSize = 0; long selectedCountsSize = 0;
long idsSize = 0; long idsSize = 0;
Sorted sorted = null;
try { try {
/* /*
* Get a count of all groups less than the maximum selected group. Count * Get a count of all groups less than the maximum selected group. Count
@ -223,12 +258,25 @@ class ValuesFloatAggregator {
ids[selectedCounts[group]++] = id; ids[selectedCounts[group]++] = id;
} }
} }
return buildOutputBlock(blockFactory, selected, selectedCounts, ids); final long totalMemoryUsed = selectedCountsSize + idsSize;
sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids);
return sorted;
} finally { } finally {
blockFactory.adjustBreaker(-selectedCountsSize - idsSize); if (sorted == null) {
blockFactory.adjustBreaker(-selectedCountsSize - idsSize);
}
} }
} }
/**
 * Returns the sorted view covering groups {@code 0..maxGroupId}, building it on the first call and
 * handing back the cached instance afterwards.
 *
 * @param other the state being merged into; not read by this implementation
 * @return the cached sorted view of this state
 */
private Sorted sortedForOrdinalMerging(GroupingState other) {
    if (sortedForOrdinalMerging != null) {
        return sortedForOrdinalMerging;
    }
    try (var allGroups = IntVector.range(0, maxGroupId + 1, blockFactory)) {
        sortedForOrdinalMerging = buildSorted(allGroups);
    }
    return sortedForOrdinalMerging;
}
Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) {
/* /*
* Insert the ids in order. * Insert the ids in order.
@ -241,11 +289,11 @@ class ValuesFloatAggregator {
int count = end - start; int count = end - start;
switch (count) { switch (count) {
case 0 -> builder.appendNull(); case 0 -> builder.appendNull();
case 1 -> append(builder, ids[start]); case 1 -> builder.appendFloat(getValue(ids[start]));
default -> { default -> {
builder.beginPositionEntry(); builder.beginPositionEntry();
for (int i = start; i < end; i++) { for (int i = start; i < end; i++) {
append(builder, ids[i]); builder.appendFloat(getValue(ids[i]));
} }
builder.endPositionEntry(); builder.endPositionEntry();
} }
@ -256,10 +304,9 @@ class ValuesFloatAggregator {
} }
} }
private void append(FloatBlock.Builder builder, int id) { float getValue(int valueId) {
long both = values.get(id); long both = values.get(valueId);
float value = Float.intBitsToFloat((int) both); return Float.intBitsToFloat((int) both);
builder.appendFloat(value);
} }
@Override @Override
@ -269,7 +316,7 @@ class ValuesFloatAggregator {
@Override @Override
public void close() { public void close() {
values.close(); Releasables.closeExpectNoException(values, sortedForOrdinalMerging);
} }
} }
} }

View File

@ -18,6 +18,8 @@ import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
/** /**
* Aggregates field values for int. * Aggregates field values for int.
@ -47,34 +49,32 @@ class ValuesIntAggregator {
return state.toBlock(driverContext.blockFactory()); return state.toBlock(driverContext.blockFactory());
} }
public static GroupingState initGrouping(BigArrays bigArrays) { public static GroupingState initGrouping(DriverContext driverContext) {
return new GroupingState(bigArrays); return new GroupingState(driverContext);
} }
public static void combine(GroupingState state, int groupId, int v) { public static void combine(GroupingState state, int groupId, int v) {
/* state.addValue(groupId, v);
* Encode the groupId and value into a single long -
* the top 32 bits for the group, the bottom 32 for the value.
*/
state.values.add((((long) groupId) << Integer.SIZE) | (v & 0xFFFFFFFFL));
} }
public static void combineIntermediate(GroupingState state, int groupId, IntBlock values, int valuesPosition) { public static void combineIntermediate(GroupingState state, int groupId, IntBlock values, int valuesPosition) {
int start = values.getFirstValueIndex(valuesPosition); int start = values.getFirstValueIndex(valuesPosition);
int end = start + values.getValueCount(valuesPosition); int end = start + values.getValueCount(valuesPosition);
for (int i = start; i < end; i++) { for (int i = start; i < end; i++) {
combine(state, groupId, values.getInt(i)); state.addValue(groupId, values.getInt(i));
} }
} }
public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) { public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) {
for (int id = 0; id < state.values.size(); id++) { var sorted = state.sortedForOrdinalMerging(current);
long both = state.values.get(id); if (statePosition > state.maxGroupId) {
int group = (int) (both >>> Integer.SIZE); return;
if (group == statePosition) { }
int value = (int) both; var start = statePosition > 0 ? sorted.counts[statePosition - 1] : 0;
combine(current, currentGroupId, value); var end = sorted.counts[statePosition];
} for (int i = start; i < end; i++) {
int id = sorted.ids[i];
current.addValue(currentGroupId, state.getValue(id));
} }
} }
@ -117,6 +117,20 @@ class ValuesIntAggregator {
} }
} }
/**
* Values are collected in a hash. Iterating over them in order (row by row) to build the output,
* or merging with other state, can be expensive. To optimize this, we build a sorted structure once,
* and then use it to iterate over the values in order.
*
* @param ids positions of the {@link GroupingState#values} to read.
*/
private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable {
@Override
public void close() {
releasable.close();
}
}
/** /**
* State for a grouped {@code VALUES} aggregation. This implementation * State for a grouped {@code VALUES} aggregation. This implementation
* emphasizes collect-time performance over the performance of rendering * emphasizes collect-time performance over the performance of rendering
@ -125,10 +139,15 @@ class ValuesIntAggregator {
* collector operation. But at least it's fairly simple. * collector operation. But at least it's fairly simple.
*/ */
public static class GroupingState implements GroupingAggregatorState { public static class GroupingState implements GroupingAggregatorState {
private int maxGroupId = -1;
private final BlockFactory blockFactory;
private final LongHash values; private final LongHash values;
private GroupingState(BigArrays bigArrays) { private Sorted sortedForOrdinalMerging = null;
values = new LongHash(1, bigArrays);
private GroupingState(DriverContext driverContext) {
this.blockFactory = driverContext.blockFactory();
values = new LongHash(1, driverContext.bigArrays());
} }
@Override @Override
@ -136,6 +155,15 @@ class ValuesIntAggregator {
blocks[offset] = toBlock(driverContext.blockFactory(), selected); blocks[offset] = toBlock(driverContext.blockFactory(), selected);
} }
void addValue(int groupId, int v) {
/*
* Encode the groupId and value into a single long -
* the top 32 bits for the group, the bottom 32 for the value.
*/
values.add((((long) groupId) << Integer.SIZE) | (v & 0xFFFFFFFFL));
maxGroupId = Math.max(maxGroupId, groupId);
}
/** /**
* Builds a {@link Block} with the unique values collected for the {@code #selected} * Builds a {@link Block} with the unique values collected for the {@code #selected}
* groups. This is the implementation of the final and intermediate results of the agg. * groups. This is the implementation of the final and intermediate results of the agg.
@ -145,8 +173,15 @@ class ValuesIntAggregator {
return blockFactory.newConstantNullBlock(selected.getPositionCount()); return blockFactory.newConstantNullBlock(selected.getPositionCount());
} }
try (var sorted = buildSorted(selected)) {
return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids);
}
}
private Sorted buildSorted(IntVector selected) {
long selectedCountsSize = 0; long selectedCountsSize = 0;
long idsSize = 0; long idsSize = 0;
Sorted sorted = null;
try { try {
/* /*
* Get a count of all groups less than the maximum selected group. Count * Get a count of all groups less than the maximum selected group. Count
@ -223,12 +258,25 @@ class ValuesIntAggregator {
ids[selectedCounts[group]++] = id; ids[selectedCounts[group]++] = id;
} }
} }
return buildOutputBlock(blockFactory, selected, selectedCounts, ids); final long totalMemoryUsed = selectedCountsSize + idsSize;
sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids);
return sorted;
} finally { } finally {
blockFactory.adjustBreaker(-selectedCountsSize - idsSize); if (sorted == null) {
blockFactory.adjustBreaker(-selectedCountsSize - idsSize);
}
} }
} }
private Sorted sortedForOrdinalMerging(GroupingState other) {
if (sortedForOrdinalMerging == null) {
try (var selected = IntVector.range(0, maxGroupId + 1, blockFactory)) {
sortedForOrdinalMerging = buildSorted(selected);
}
}
return sortedForOrdinalMerging;
}
Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) {
/* /*
* Insert the ids in order. * Insert the ids in order.
@ -241,11 +289,11 @@ class ValuesIntAggregator {
int count = end - start; int count = end - start;
switch (count) { switch (count) {
case 0 -> builder.appendNull(); case 0 -> builder.appendNull();
case 1 -> append(builder, ids[start]); case 1 -> builder.appendInt(getValue(ids[start]));
default -> { default -> {
builder.beginPositionEntry(); builder.beginPositionEntry();
for (int i = start; i < end; i++) { for (int i = start; i < end; i++) {
append(builder, ids[i]); builder.appendInt(getValue(ids[i]));
} }
builder.endPositionEntry(); builder.endPositionEntry();
} }
@ -256,10 +304,9 @@ class ValuesIntAggregator {
} }
} }
private void append(IntBlock.Builder builder, int id) { int getValue(int valueId) {
long both = values.get(id); long both = values.get(valueId);
int value = (int) both; return (int) both;
builder.appendInt(value);
} }
@Override @Override
@ -269,7 +316,7 @@ class ValuesIntAggregator {
@Override @Override
public void close() { public void close() {
values.close(); Releasables.closeExpectNoException(values, sortedForOrdinalMerging);
} }
} }
} }

View File

@ -19,6 +19,8 @@ import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
/** /**
* Aggregates field values for long. * Aggregates field values for long.
@ -48,28 +50,32 @@ class ValuesLongAggregator {
return state.toBlock(driverContext.blockFactory()); return state.toBlock(driverContext.blockFactory());
} }
public static GroupingState initGrouping(BigArrays bigArrays) { public static GroupingState initGrouping(DriverContext driverContext) {
return new GroupingState(bigArrays); return new GroupingState(driverContext);
} }
public static void combine(GroupingState state, int groupId, long v) { public static void combine(GroupingState state, int groupId, long v) {
state.values.add(groupId, v); state.addValue(groupId, v);
} }
public static void combineIntermediate(GroupingState state, int groupId, LongBlock values, int valuesPosition) { public static void combineIntermediate(GroupingState state, int groupId, LongBlock values, int valuesPosition) {
int start = values.getFirstValueIndex(valuesPosition); int start = values.getFirstValueIndex(valuesPosition);
int end = start + values.getValueCount(valuesPosition); int end = start + values.getValueCount(valuesPosition);
for (int i = start; i < end; i++) { for (int i = start; i < end; i++) {
combine(state, groupId, values.getLong(i)); state.addValue(groupId, values.getLong(i));
} }
} }
public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) { public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) {
for (int id = 0; id < state.values.size(); id++) { var sorted = state.sortedForOrdinalMerging(current);
if (state.values.getKey1(id) == statePosition) { if (statePosition > state.maxGroupId) {
long value = state.values.getKey2(id); return;
combine(current, currentGroupId, value); }
} var start = statePosition > 0 ? sorted.counts[statePosition - 1] : 0;
var end = sorted.counts[statePosition];
for (int i = start; i < end; i++) {
int id = sorted.ids[i];
current.addValue(currentGroupId, state.getValue(id));
} }
} }
@ -112,6 +118,20 @@ class ValuesLongAggregator {
} }
} }
/**
* Values are collected in a hash. Iterating over them in order (row by row) to build the output,
* or merging with other state, can be expensive. To optimize this, we build a sorted structure once,
* and then use it to iterate over the values in order.
*
* @param ids positions of the {@link GroupingState#values} to read.
*/
private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable {
@Override
public void close() {
releasable.close();
}
}
/** /**
* State for a grouped {@code VALUES} aggregation. This implementation * State for a grouped {@code VALUES} aggregation. This implementation
* emphasizes collect-time performance over the performance of rendering * emphasizes collect-time performance over the performance of rendering
@ -120,10 +140,15 @@ class ValuesLongAggregator {
* collector operation. But at least it's fairly simple. * collector operation. But at least it's fairly simple.
*/ */
public static class GroupingState implements GroupingAggregatorState { public static class GroupingState implements GroupingAggregatorState {
private int maxGroupId = -1;
private final BlockFactory blockFactory;
private final LongLongHash values; private final LongLongHash values;
private GroupingState(BigArrays bigArrays) { private Sorted sortedForOrdinalMerging = null;
values = new LongLongHash(1, bigArrays);
private GroupingState(DriverContext driverContext) {
this.blockFactory = driverContext.blockFactory();
values = new LongLongHash(1, driverContext.bigArrays());
} }
@Override @Override
@ -131,6 +156,11 @@ class ValuesLongAggregator {
blocks[offset] = toBlock(driverContext.blockFactory(), selected); blocks[offset] = toBlock(driverContext.blockFactory(), selected);
} }
void addValue(int groupId, long v) {
values.add(groupId, v);
maxGroupId = Math.max(maxGroupId, groupId);
}
/** /**
* Builds a {@link Block} with the unique values collected for the {@code #selected} * Builds a {@link Block} with the unique values collected for the {@code #selected}
* groups. This is the implementation of the final and intermediate results of the agg. * groups. This is the implementation of the final and intermediate results of the agg.
@ -140,8 +170,15 @@ class ValuesLongAggregator {
return blockFactory.newConstantNullBlock(selected.getPositionCount()); return blockFactory.newConstantNullBlock(selected.getPositionCount());
} }
try (var sorted = buildSorted(selected)) {
return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids);
}
}
private Sorted buildSorted(IntVector selected) {
long selectedCountsSize = 0; long selectedCountsSize = 0;
long idsSize = 0; long idsSize = 0;
Sorted sorted = null;
try { try {
/* /*
* Get a count of all groups less than the maximum selected group. Count * Get a count of all groups less than the maximum selected group. Count
@ -216,12 +253,25 @@ class ValuesLongAggregator {
ids[selectedCounts[group]++] = id; ids[selectedCounts[group]++] = id;
} }
} }
return buildOutputBlock(blockFactory, selected, selectedCounts, ids); final long totalMemoryUsed = selectedCountsSize + idsSize;
sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids);
return sorted;
} finally { } finally {
blockFactory.adjustBreaker(-selectedCountsSize - idsSize); if (sorted == null) {
blockFactory.adjustBreaker(-selectedCountsSize - idsSize);
}
} }
} }
private Sorted sortedForOrdinalMerging(GroupingState other) {
if (sortedForOrdinalMerging == null) {
try (var selected = IntVector.range(0, maxGroupId + 1, blockFactory)) {
sortedForOrdinalMerging = buildSorted(selected);
}
}
return sortedForOrdinalMerging;
}
Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) {
/* /*
* Insert the ids in order. * Insert the ids in order.
@ -234,11 +284,11 @@ class ValuesLongAggregator {
int count = end - start; int count = end - start;
switch (count) { switch (count) {
case 0 -> builder.appendNull(); case 0 -> builder.appendNull();
case 1 -> append(builder, ids[start]); case 1 -> builder.appendLong(getValue(ids[start]));
default -> { default -> {
builder.beginPositionEntry(); builder.beginPositionEntry();
for (int i = start; i < end; i++) { for (int i = start; i < end; i++) {
append(builder, ids[i]); builder.appendLong(getValue(ids[i]));
} }
builder.endPositionEntry(); builder.endPositionEntry();
} }
@ -249,9 +299,8 @@ class ValuesLongAggregator {
} }
} }
private void append(LongBlock.Builder builder, int id) { long getValue(int valueId) {
long value = values.getKey2(id); return values.getKey2(valueId);
builder.appendLong(value);
} }
@Override @Override
@ -261,7 +310,7 @@ class ValuesLongAggregator {
@Override @Override
public void close() { public void close() {
values.close(); Releasables.closeExpectNoException(values, sortedForOrdinalMerging);
} }
} }
} }

View File

@ -43,7 +43,7 @@ public final class ValuesBytesRefGroupingAggregatorFunction implements GroupingA
public static ValuesBytesRefGroupingAggregatorFunction create(List<Integer> channels, public static ValuesBytesRefGroupingAggregatorFunction create(List<Integer> channels,
DriverContext driverContext) { DriverContext driverContext) {
return new ValuesBytesRefGroupingAggregatorFunction(channels, ValuesBytesRefAggregator.initGrouping(driverContext.bigArrays()), driverContext); return new ValuesBytesRefGroupingAggregatorFunction(channels, ValuesBytesRefAggregator.initGrouping(driverContext), driverContext);
} }
public static List<IntermediateStateDesc> intermediateStateDesc() { public static List<IntermediateStateDesc> intermediateStateDesc() {

View File

@ -42,7 +42,7 @@ public final class ValuesDoubleGroupingAggregatorFunction implements GroupingAgg
public static ValuesDoubleGroupingAggregatorFunction create(List<Integer> channels, public static ValuesDoubleGroupingAggregatorFunction create(List<Integer> channels,
DriverContext driverContext) { DriverContext driverContext) {
return new ValuesDoubleGroupingAggregatorFunction(channels, ValuesDoubleAggregator.initGrouping(driverContext.bigArrays()), driverContext); return new ValuesDoubleGroupingAggregatorFunction(channels, ValuesDoubleAggregator.initGrouping(driverContext), driverContext);
} }
public static List<IntermediateStateDesc> intermediateStateDesc() { public static List<IntermediateStateDesc> intermediateStateDesc() {

View File

@ -42,7 +42,7 @@ public final class ValuesFloatGroupingAggregatorFunction implements GroupingAggr
public static ValuesFloatGroupingAggregatorFunction create(List<Integer> channels, public static ValuesFloatGroupingAggregatorFunction create(List<Integer> channels,
DriverContext driverContext) { DriverContext driverContext) {
return new ValuesFloatGroupingAggregatorFunction(channels, ValuesFloatAggregator.initGrouping(driverContext.bigArrays()), driverContext); return new ValuesFloatGroupingAggregatorFunction(channels, ValuesFloatAggregator.initGrouping(driverContext), driverContext);
} }
public static List<IntermediateStateDesc> intermediateStateDesc() { public static List<IntermediateStateDesc> intermediateStateDesc() {

View File

@ -41,7 +41,7 @@ public final class ValuesIntGroupingAggregatorFunction implements GroupingAggreg
public static ValuesIntGroupingAggregatorFunction create(List<Integer> channels, public static ValuesIntGroupingAggregatorFunction create(List<Integer> channels,
DriverContext driverContext) { DriverContext driverContext) {
return new ValuesIntGroupingAggregatorFunction(channels, ValuesIntAggregator.initGrouping(driverContext.bigArrays()), driverContext); return new ValuesIntGroupingAggregatorFunction(channels, ValuesIntAggregator.initGrouping(driverContext), driverContext);
} }
public static List<IntermediateStateDesc> intermediateStateDesc() { public static List<IntermediateStateDesc> intermediateStateDesc() {

View File

@ -42,7 +42,7 @@ public final class ValuesLongGroupingAggregatorFunction implements GroupingAggre
public static ValuesLongGroupingAggregatorFunction create(List<Integer> channels, public static ValuesLongGroupingAggregatorFunction create(List<Integer> channels,
DriverContext driverContext) { DriverContext driverContext) {
return new ValuesLongGroupingAggregatorFunction(channels, ValuesLongAggregator.initGrouping(driverContext.bigArrays()), driverContext); return new ValuesLongGroupingAggregatorFunction(channels, ValuesLongAggregator.initGrouping(driverContext), driverContext);
} }
public static List<IntermediateStateDesc> intermediateStateDesc() { public static List<IntermediateStateDesc> intermediateStateDesc() {

View File

@ -55,7 +55,7 @@ final class ValuesBytesRefAggregators {
int valuesStart = ordinalIds.getFirstValueIndex(groupPosition + positionOffset); int valuesStart = ordinalIds.getFirstValueIndex(groupPosition + positionOffset);
int valuesEnd = valuesStart + ordinalIds.getValueCount(groupPosition + positionOffset); int valuesEnd = valuesStart + ordinalIds.getValueCount(groupPosition + positionOffset);
for (int v = valuesStart; v < valuesEnd; v++) { for (int v = valuesStart; v < valuesEnd; v++) {
state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(v))); state.addValueOrdinal(groupId, hashIds.getInt(ordinalIds.getInt(v)));
} }
} }
} }
@ -77,7 +77,7 @@ final class ValuesBytesRefAggregators {
int valuesStart = ordinalIds.getFirstValueIndex(groupPosition + positionOffset); int valuesStart = ordinalIds.getFirstValueIndex(groupPosition + positionOffset);
int valuesEnd = valuesStart + ordinalIds.getValueCount(groupPosition + positionOffset); int valuesEnd = valuesStart + ordinalIds.getValueCount(groupPosition + positionOffset);
for (int v = valuesStart; v < valuesEnd; v++) { for (int v = valuesStart; v < valuesEnd; v++) {
state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(v))); state.addValueOrdinal(groupId, hashIds.getInt(ordinalIds.getInt(v)));
} }
} }
} }
@ -93,7 +93,7 @@ final class ValuesBytesRefAggregators {
int valuesStart = ordinalIds.getFirstValueIndex(groupPosition + positionOffset); int valuesStart = ordinalIds.getFirstValueIndex(groupPosition + positionOffset);
int valuesEnd = valuesStart + ordinalIds.getValueCount(groupPosition + positionOffset); int valuesEnd = valuesStart + ordinalIds.getValueCount(groupPosition + positionOffset);
for (int v = valuesStart; v < valuesEnd; v++) { for (int v = valuesStart; v < valuesEnd; v++) {
state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(v))); state.addValueOrdinal(groupId, hashIds.getInt(ordinalIds.getInt(v)));
} }
} }
} }
@ -135,7 +135,7 @@ final class ValuesBytesRefAggregators {
int groupEnd = groupStart + groupIds.getValueCount(groupPosition); int groupEnd = groupStart + groupIds.getValueCount(groupPosition);
for (int g = groupStart; g < groupEnd; g++) { for (int g = groupStart; g < groupEnd; g++) {
int groupId = groupIds.getInt(g); int groupId = groupIds.getInt(g);
state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset))); state.addValueOrdinal(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset)));
} }
} }
} }
@ -150,7 +150,7 @@ final class ValuesBytesRefAggregators {
int groupEnd = groupStart + groupIds.getValueCount(groupPosition); int groupEnd = groupStart + groupIds.getValueCount(groupPosition);
for (int g = groupStart; g < groupEnd; g++) { for (int g = groupStart; g < groupEnd; g++) {
int groupId = groupIds.getInt(g); int groupId = groupIds.getInt(g);
state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset))); state.addValueOrdinal(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset)));
} }
} }
} }
@ -159,7 +159,7 @@ final class ValuesBytesRefAggregators {
public void add(int positionOffset, IntVector groupIds) { public void add(int positionOffset, IntVector groupIds) {
for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) { for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) {
int groupId = groupIds.getInt(groupPosition); int groupId = groupIds.getInt(groupPosition);
state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset))); state.addValueOrdinal(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset)));
} }
} }

View File

@ -43,12 +43,9 @@ $if(BytesRef)$
import org.elasticsearch.compute.data.OrdinalBytesRefBlock; import org.elasticsearch.compute.data.OrdinalBytesRefBlock;
$endif$ $endif$
import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.DriverContext;
$if(BytesRef)$ import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables; import org.elasticsearch.core.Releasables;
$else$
$endif$
/** /**
* Aggregates field values for $type$. * Aggregates field values for $type$.
* This class is generated. Edit {@code X-ValuesAggregator.java.st} instead * This class is generated. Edit {@code X-ValuesAggregator.java.st} instead
@ -90,8 +87,8 @@ $endif$
return state.toBlock(driverContext.blockFactory()); return state.toBlock(driverContext.blockFactory());
} }
public static GroupingState initGrouping(BigArrays bigArrays) { public static GroupingState initGrouping(DriverContext driverContext) {
return new GroupingState(bigArrays); return new GroupingState(driverContext);
} }
$if(BytesRef)$ $if(BytesRef)$
@ -113,25 +110,7 @@ $if(BytesRef)$
$endif$ $endif$
public static void combine(GroupingState state, int groupId, $type$ v) { public static void combine(GroupingState state, int groupId, $type$ v) {
$if(long)$ state.addValue(groupId, v);
state.values.add(groupId, v);
$elseif(double)$
state.values.add(groupId, Double.doubleToLongBits(v));
$elseif(BytesRef)$
state.values.add(groupId, BlockHash.hashOrdToGroup(state.bytes.add(v)));
$elseif(int)$
/*
* Encode the groupId and value into a single long -
* the top 32 bits for the group, the bottom 32 for the value.
*/
state.values.add((((long) groupId) << Integer.SIZE) | (v & 0xFFFFFFFFL));
$elseif(float)$
/*
* Encode the groupId and value into a single long -
* the top 32 bits for the group, the bottom 32 for the value.
*/
state.values.add((((long) groupId) << Float.SIZE) | (Float.floatToIntBits(v) & 0xFFFFFFFFL));
$endif$
} }
public static void combineIntermediate(GroupingState state, int groupId, $Type$Block values, int valuesPosition) { public static void combineIntermediate(GroupingState state, int groupId, $Type$Block values, int valuesPosition) {
@ -142,37 +121,27 @@ $endif$
int end = start + values.getValueCount(valuesPosition); int end = start + values.getValueCount(valuesPosition);
for (int i = start; i < end; i++) { for (int i = start; i < end; i++) {
$if(BytesRef)$ $if(BytesRef)$
combine(state, groupId, values.getBytesRef(i, scratch)); state.addValue(groupId, values.getBytesRef(i, scratch));
$else$ $else$
combine(state, groupId, values.get$Type$(i)); state.addValue(groupId, values.get$Type$(i));
$endif$ $endif$
} }
} }
public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) { public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) {
var sorted = state.sortedForOrdinalMerging(current);
if (statePosition > state.maxGroupId) {
return;
}
var start = statePosition > 0 ? sorted.counts[statePosition - 1] : 0;
var end = sorted.counts[statePosition];
for (int i = start; i < end; i++) {
int id = sorted.ids[i];
$if(BytesRef)$ $if(BytesRef)$
BytesRef scratch = new BytesRef(); current.addValueOrdinal(currentGroupId, id);
$else$
current.addValue(currentGroupId, state.getValue(id));
$endif$ $endif$
for (int id = 0; id < state.values.size(); id++) {
$if(long||BytesRef)$
if (state.values.getKey1(id) == statePosition) {
long value = state.values.getKey2(id);
$elseif(double)$
if (state.values.getKey1(id) == statePosition) {
double value = Double.longBitsToDouble(state.values.getKey2(id));
$elseif(int)$
long both = state.values.get(id);
int group = (int) (both >>> Integer.SIZE);
if (group == statePosition) {
int value = (int) both;
$elseif(float)$
long both = state.values.get(id);
int group = (int) (both >>> Float.SIZE);
if (group == statePosition) {
float value = Float.intBitsToFloat((int) both);
$endif$
combine(current, currentGroupId, $if(BytesRef)$state.bytes.get(value, scratch)$else$value$endif$);
}
} }
} }
@ -247,6 +216,24 @@ $endif$
} }
} }
/**
* Values are collected in a hash. Iterating over them in order (row by row) to build the output,
* or merging with other state, can be expensive. To optimize this, we build a sorted structure once,
* and then use it to iterate over the values in order.
*
* @param ids positions of the {@link GroupingState#values} to read.
$if(BytesRef)$
* If built from {@link GroupingState#sortedForOrdinalMerging(GroupingState)},
* these are ordinals referring to the {@link GroupingState#bytes} in the target state.
$endif$
*/
private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable {
@Override
public void close() {
releasable.close();
}
}
/** /**
* State for a grouped {@code VALUES} aggregation. This implementation * State for a grouped {@code VALUES} aggregation. This implementation
* emphasizes collect-time performance over the performance of rendering * emphasizes collect-time performance over the performance of rendering
@ -255,26 +242,31 @@ $endif$
* collector operation. But at least it's fairly simple. * collector operation. But at least it's fairly simple.
*/ */
public static class GroupingState implements GroupingAggregatorState { public static class GroupingState implements GroupingAggregatorState {
private int maxGroupId = -1;
private final BlockFactory blockFactory;
$if(long||double)$ $if(long||double)$
private final LongLongHash values; private final LongLongHash values;
$elseif(BytesRef)$ $elseif(BytesRef)$
final LongLongHash values; private final LongLongHash values;
BytesRefHash bytes; BytesRefHash bytes;
$elseif(int||float)$ $elseif(int||float)$
private final LongHash values; private final LongHash values;
$endif$ $endif$
private GroupingState(BigArrays bigArrays) { private Sorted sortedForOrdinalMerging = null;
private GroupingState(DriverContext driverContext) {
this.blockFactory = driverContext.blockFactory();
$if(long||double)$ $if(long||double)$
values = new LongLongHash(1, bigArrays); values = new LongLongHash(1, driverContext.bigArrays());
$elseif(BytesRef)$ $elseif(BytesRef)$
LongLongHash _values = null; LongLongHash _values = null;
BytesRefHash _bytes = null; BytesRefHash _bytes = null;
try { try {
_values = new LongLongHash(1, bigArrays); _values = new LongLongHash(1, driverContext.bigArrays());
_bytes = new BytesRefHash(1, bigArrays); _bytes = new BytesRefHash(1, driverContext.bigArrays());
values = _values; values = _values;
bytes = _bytes; bytes = _bytes;
@ -285,7 +277,7 @@ $elseif(BytesRef)$
Releasables.closeExpectNoException(_values, _bytes); Releasables.closeExpectNoException(_values, _bytes);
} }
$elseif(int||float)$ $elseif(int||float)$
values = new LongHash(1, bigArrays); values = new LongHash(1, driverContext.bigArrays());
$endif$ $endif$
} }
@ -294,6 +286,36 @@ $endif$
blocks[offset] = toBlock(driverContext.blockFactory(), selected); blocks[offset] = toBlock(driverContext.blockFactory(), selected);
} }
$if(BytesRef)$
void addValueOrdinal(int groupId, long valueOrdinal) {
values.add(groupId, valueOrdinal);
maxGroupId = Math.max(maxGroupId, groupId);
}
$endif$
void addValue(int groupId, $type$ v) {
$if(long)$
values.add(groupId, v);
$elseif(double)$
values.add(groupId, Double.doubleToLongBits(v));
$elseif(BytesRef)$
values.add(groupId, BlockHash.hashOrdToGroup(bytes.add(v)));
$elseif(int)$
/*
* Encode the groupId and value into a single long -
* the top 32 bits for the group, the bottom 32 for the value.
*/
values.add((((long) groupId) << Integer.SIZE) | (v & 0xFFFFFFFFL));
$elseif(float)$
/*
* Encode the groupId and value into a single long -
* the top 32 bits for the group, the bottom 32 for the value.
*/
values.add((((long) groupId) << Float.SIZE) | (Float.floatToIntBits(v) & 0xFFFFFFFFL));
$endif$
maxGroupId = Math.max(maxGroupId, groupId);
}
/** /**
* Builds a {@link Block} with the unique values collected for the {@code #selected} * Builds a {@link Block} with the unique values collected for the {@code #selected}
* groups. This is the implementation of the final and intermediate results of the agg. * groups. This is the implementation of the final and intermediate results of the agg.
@ -303,8 +325,23 @@ $endif$
return blockFactory.newConstantNullBlock(selected.getPositionCount()); return blockFactory.newConstantNullBlock(selected.getPositionCount());
} }
try (var sorted = buildSorted(selected)) {
$if(BytesRef)$
if (OrdinalBytesRefBlock.isDense(selected.getPositionCount(), Math.toIntExact(values.size()))) {
return buildOrdinalOutputBlock(blockFactory, selected, sorted.counts, sorted.ids);
} else {
return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids);
}
$else$
return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids);
$endif$
}
}
private Sorted buildSorted(IntVector selected) {
long selectedCountsSize = 0; long selectedCountsSize = 0;
long idsSize = 0; long idsSize = 0;
Sorted sorted = null;
try { try {
/* /*
* Get a count of all groups less than the maximum selected group. Count * Get a count of all groups less than the maximum selected group. Count
@ -379,30 +416,56 @@ $endif$
idsSize = adjust; idsSize = adjust;
int[] ids = new int[total]; int[] ids = new int[total];
for (int id = 0; id < values.size(); id++) { for (int id = 0; id < values.size(); id++) {
$if(long||BytesRef||double)$ $if(long||BytesRef||double)$
int group = (int) values.getKey1(id); int group = (int) values.getKey1(id);
$elseif(float||int)$ $elseif(float||int)$
long both = values.get(id); long both = values.get(id);
int group = (int) (both >>> Float.SIZE); int group = (int) (both >>> Float.SIZE);
$endif$ $endif$
if (group < selectedCounts.length && selectedCounts[group] >= 0) { if (group < selectedCounts.length && selectedCounts[group] >= 0) {
ids[selectedCounts[group]++] = id; ids[selectedCounts[group]++] = id;
} }
} }
$if(BytesRef)$ final long totalMemoryUsed = selectedCountsSize + idsSize;
if (OrdinalBytesRefBlock.isDense(selected.getPositionCount(), Math.toIntExact(values.size()))) { sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids);
return buildOrdinalOutputBlock(blockFactory, selected, selectedCounts, ids); return sorted;
} else {
return buildOutputBlock(blockFactory, selected, selectedCounts, ids);
}
$else$
return buildOutputBlock(blockFactory, selected, selectedCounts, ids);
$endif$
} finally { } finally {
blockFactory.adjustBreaker(-selectedCountsSize - idsSize); if (sorted == null) {
blockFactory.adjustBreaker(-selectedCountsSize - idsSize);
}
} }
} }
/**
 * Returns the {@link Sorted} structure used when merging this state into {@code other},
 * building it on the first call and handing back the cached instance on every later call.
 * <p>
 * The structure comes from {@code buildSorted} applied to the vector created by
 * {@code IntVector.range(0, maxGroupId + 1, blockFactory)}. In the BytesRef variant of this
 * template the stored ids are then rewritten in place: every value held in {@code bytes} is
 * added to {@code other.bytes} once, and each id is replaced by the id derived from that add.
 * After that rewrite {@code bytes} is closed and set to {@code null}.
 * <p>
 * NOTE(review): the cache check does not look at {@code other}. In the BytesRef variant the
 * cached ids were derived from the {@code other} of the first call, so a later call with a
 * different {@code other} gets the same ids back - presumably callers always pass the same
 * destination; confirm against the call sites.
 *
 * @param other the destination state; only used by the BytesRef variant, which adds values to
 *              {@code other.bytes}
 * @return the cached sorted structure
 */
private Sorted sortedForOrdinalMerging(GroupingState other) {
if (sortedForOrdinalMerging == null) {
// selected is only needed while building; try-with-resources closes it on the way out
try (var selected = IntVector.range(0, maxGroupId + 1, blockFactory)) {
sortedForOrdinalMerging = buildSorted(selected);
$if(BytesRef)$
// hash all the bytes to the destination to avoid hashing them multiple times
BytesRef scratch = new BytesRef();
final int totalValue = Math.toIntExact(bytes.size());
// reserve breaker memory for the temporary mappedIds array (one int per entry in bytes)
// before allocating it
blockFactory.adjustBreaker((long) totalValue * Integer.BYTES);
try {
// mappedIds[i] holds the id computed from adding entry i of bytes to other.bytes
final int[] mappedIds = new int[totalValue];
for (int i = 0; i < totalValue; i++) {
var v = bytes.get(i, scratch);
// NOTE(review): hashOrdToGroup presumably turns the add() result into the id of the value
// in other.bytes whether it was newly added or already present - confirm in BlockHash
mappedIds[i] = Math.toIntExact(BlockHash.hashOrdToGroup(other.bytes.add(v)));
}
// no longer need the bytes
bytes.close();
// nulled so that no later code touches the closed instance
bytes = null;
// each stored id indexes values; getKey2 of that entry indexes mappedIds, and the stored id
// is overwritten with the mapped id
for (int i = 0; i < sortedForOrdinalMerging.ids.length; i++) {
sortedForOrdinalMerging.ids[i] = mappedIds[Math.toIntExact(values.getKey2(sortedForOrdinalMerging.ids[i]))];
}
} finally {
// give back the mappedIds reservation whether or not the mapping completed
blockFactory.adjustBreaker(-(long) totalValue * Integer.BYTES);
}
$endif$
}
}
return sortedForOrdinalMerging;
}
Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) {
/* /*
* Insert the ids in order. * Insert the ids in order.
@ -418,11 +481,11 @@ $endif$
int count = end - start; int count = end - start;
switch (count) { switch (count) {
case 0 -> builder.appendNull(); case 0 -> builder.appendNull();
case 1 -> append(builder, ids[start]$if(BytesRef)$, scratch$endif$); case 1 -> builder.append$Type$(getValue(ids[start]$if(BytesRef)$, scratch$endif$));
default -> { default -> {
builder.beginPositionEntry(); builder.beginPositionEntry();
for (int i = start; i < end; i++) { for (int i = start; i < end; i++) {
append(builder, ids[i]$if(BytesRef)$, scratch$endif$); builder.append$Type$(getValue(ids[i]$if(BytesRef)$, scratch$endif$));
} }
builder.endPositionEntry(); builder.endPositionEntry();
} }
@ -470,29 +533,24 @@ $if(BytesRef)$
} }
} }
} }
$endif$
private void append($Type$Block.Builder builder, int id, BytesRef scratch) { $type$ getValue(int valueId$if(BytesRef)$, BytesRef scratch$endif$) {
BytesRef value = bytes.get(values.getKey2(id), scratch); $if(BytesRef)$
builder.appendBytesRef(value); return bytes.get(values.getKey2(valueId), scratch);
} $elseif(long)$
return values.getKey2(valueId);
$else$
private void append($Type$Block.Builder builder, int id) {
$if(long)$
long value = values.getKey2(id);
$elseif(double)$ $elseif(double)$
double value = Double.longBitsToDouble(values.getKey2(id)); return Double.longBitsToDouble(values.getKey2(valueId));
$elseif(float)$ $elseif(float)$
long both = values.get(id); long both = values.get(valueId);
float value = Float.intBitsToFloat((int) both); return Float.intBitsToFloat((int) both);
$elseif(int)$ $elseif(int)$
long both = values.get(id); long both = values.get(valueId);
int value = (int) both; return (int) both;
$endif$ $endif$
builder.append$Type$(value);
} }
$endif$
@Override @Override
public void enableGroupIdTracking(SeenGroupIds seen) { public void enableGroupIdTracking(SeenGroupIds seen) {
// we figure out seen values from nulls on the values block // we figure out seen values from nulls on the values block
@ -501,9 +559,9 @@ $endif$
@Override @Override
public void close() { public void close() {
$if(BytesRef)$ $if(BytesRef)$
Releasables.closeExpectNoException(values, bytes); Releasables.closeExpectNoException(values, bytes, sortedForOrdinalMerging);
$else$ $else$
values.close(); Releasables.closeExpectNoException(values, sortedForOrdinalMerging);
$endif$ $endif$
} }
} }

View File

@ -85,7 +85,7 @@ public class HashAggregationOperator implements Operator {
private final BlockHash blockHash; private final BlockHash blockHash;
private final List<GroupingAggregator> aggregators; protected final List<GroupingAggregator> aggregators;
protected final DriverContext driverContext; protected final DriverContext driverContext;