Avoid O(N^2) in VALUES with ordinals grouping (#130576)
Using the VALUES aggregator with ordinals grouping led to accidental quadratic complexity. Queries like FROM .. | STATS ... VALUES(field) ... BY keyword-field are affected by this performance issue. This change builds the sorted structure - previously used to fix a similar O(N^2) problem when emitting the output block - once, caches it, and reuses it for every group merged during the merging phase of the OrdinalsGroupingOperator.
This commit is contained in:
parent
02b2f5eb66
commit
59df1bfd51
|
@ -95,7 +95,8 @@ public class ValuesAggregatorBenchmark {
|
|||
try {
|
||||
for (String groups : ValuesAggregatorBenchmark.class.getField("groups").getAnnotationsByType(Param.class)[0].value()) {
|
||||
for (String dataType : ValuesAggregatorBenchmark.class.getField("dataType").getAnnotationsByType(Param.class)[0].value()) {
|
||||
run(Integer.parseInt(groups), dataType, 10);
|
||||
run(Integer.parseInt(groups), dataType, 10, 0);
|
||||
run(Integer.parseInt(groups), dataType, 10, 1);
|
||||
}
|
||||
}
|
||||
} catch (NoSuchFieldException e) {
|
||||
|
@ -113,7 +114,10 @@ public class ValuesAggregatorBenchmark {
|
|||
@Param({ BYTES_REF, INT, LONG })
|
||||
public String dataType;
|
||||
|
||||
private static Operator operator(DriverContext driverContext, int groups, String dataType) {
|
||||
@Param({ "0", "1" })
|
||||
public int numOrdinalMerges;
|
||||
|
||||
private static Operator operator(DriverContext driverContext, int groups, String dataType, int numOrdinalMerges) {
|
||||
if (groups == 1) {
|
||||
return new AggregationOperator(
|
||||
List.of(supplier(dataType).aggregatorFactory(AggregatorMode.SINGLE, List.of(0)).apply(driverContext)),
|
||||
|
@ -125,7 +129,24 @@ public class ValuesAggregatorBenchmark {
|
|||
List.of(supplier(dataType).groupingAggregatorFactory(AggregatorMode.SINGLE, List.of(1))),
|
||||
() -> BlockHash.build(groupSpec, driverContext.blockFactory(), 16 * 1024, false),
|
||||
driverContext
|
||||
);
|
||||
) {
|
||||
@Override
|
||||
public Page getOutput() {
|
||||
mergeOrdinal();
|
||||
return super.getOutput();
|
||||
}
|
||||
|
||||
// simulate OrdinalsGroupingOperator
|
||||
void mergeOrdinal() {
|
||||
var merged = supplier(dataType).groupingAggregatorFactory(AggregatorMode.SINGLE, List.of(1)).apply(driverContext);
|
||||
for (int i = 0; i < numOrdinalMerges; i++) {
|
||||
for (int p = 0; p < groups; p++) {
|
||||
merged.addIntermediateRow(p, aggregators.getFirst(), p);
|
||||
}
|
||||
}
|
||||
aggregators.set(0, merged);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
private static AggregatorFunctionSupplier supplier(String dataType) {
|
||||
|
@ -331,12 +352,12 @@ public class ValuesAggregatorBenchmark {
|
|||
|
||||
@Benchmark
|
||||
public void run() {
|
||||
run(groups, dataType, OP_COUNT);
|
||||
run(groups, dataType, OP_COUNT, numOrdinalMerges);
|
||||
}
|
||||
|
||||
private static void run(int groups, String dataType, int opCount) {
|
||||
private static void run(int groups, String dataType, int opCount, int numOrdinalMerges) {
|
||||
DriverContext driverContext = driverContext();
|
||||
try (Operator operator = operator(driverContext, groups, dataType)) {
|
||||
try (Operator operator = operator(driverContext, groups, dataType, numOrdinalMerges)) {
|
||||
Page page = page(groups, dataType);
|
||||
for (int i = 0; i < opCount; i++) {
|
||||
operator.addInput(page.shallowCopy());
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
pr: 130576
|
||||
summary: Avoid O(N^2) in VALUES with ordinals grouping
|
||||
area: ES|QL
|
||||
type: bug
|
||||
issues: []
|
|
@ -24,6 +24,7 @@ import org.elasticsearch.compute.data.IntBlock;
|
|||
import org.elasticsearch.compute.data.IntVector;
|
||||
import org.elasticsearch.compute.data.OrdinalBytesRefBlock;
|
||||
import org.elasticsearch.compute.operator.DriverContext;
|
||||
import org.elasticsearch.core.Releasable;
|
||||
import org.elasticsearch.core.Releasables;
|
||||
|
||||
/**
|
||||
|
@ -55,8 +56,8 @@ class ValuesBytesRefAggregator {
|
|||
return state.toBlock(driverContext.blockFactory());
|
||||
}
|
||||
|
||||
public static GroupingState initGrouping(BigArrays bigArrays) {
|
||||
return new GroupingState(bigArrays);
|
||||
public static GroupingState initGrouping(DriverContext driverContext) {
|
||||
return new GroupingState(driverContext);
|
||||
}
|
||||
|
||||
public static GroupingAggregatorFunction.AddInput wrapAddInput(
|
||||
|
@ -76,7 +77,7 @@ class ValuesBytesRefAggregator {
|
|||
}
|
||||
|
||||
public static void combine(GroupingState state, int groupId, BytesRef v) {
|
||||
state.values.add(groupId, BlockHash.hashOrdToGroup(state.bytes.add(v)));
|
||||
state.addValue(groupId, v);
|
||||
}
|
||||
|
||||
public static void combineIntermediate(GroupingState state, int groupId, BytesRefBlock values, int valuesPosition) {
|
||||
|
@ -84,17 +85,20 @@ class ValuesBytesRefAggregator {
|
|||
int start = values.getFirstValueIndex(valuesPosition);
|
||||
int end = start + values.getValueCount(valuesPosition);
|
||||
for (int i = start; i < end; i++) {
|
||||
combine(state, groupId, values.getBytesRef(i, scratch));
|
||||
state.addValue(groupId, values.getBytesRef(i, scratch));
|
||||
}
|
||||
}
|
||||
|
||||
public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) {
|
||||
BytesRef scratch = new BytesRef();
|
||||
for (int id = 0; id < state.values.size(); id++) {
|
||||
if (state.values.getKey1(id) == statePosition) {
|
||||
long value = state.values.getKey2(id);
|
||||
combine(current, currentGroupId, state.bytes.get(value, scratch));
|
||||
}
|
||||
var sorted = state.sortedForOrdinalMerging(current);
|
||||
if (statePosition > state.maxGroupId) {
|
||||
return;
|
||||
}
|
||||
var start = statePosition > 0 ? sorted.counts[statePosition - 1] : 0;
|
||||
var end = sorted.counts[statePosition];
|
||||
for (int i = start; i < end; i++) {
|
||||
int id = sorted.ids[i];
|
||||
current.addValueOrdinal(currentGroupId, id);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -138,6 +142,22 @@ class ValuesBytesRefAggregator {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Values are collected in a hash. Iterating over them in order (row by row) to build the output,
|
||||
* or merging with other state, can be expensive. To optimize this, we build a sorted structure once,
|
||||
* and then use it to iterate over the values in order.
|
||||
*
|
||||
* @param ids positions of the {@link GroupingState#values} to read.
|
||||
* If built from {@link GroupingState#sortedForOrdinalMerging(GroupingState)},
|
||||
* these are ordinals referring to the {@link GroupingState#bytes} in the target state.
|
||||
*/
|
||||
private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable {
|
||||
@Override
|
||||
public void close() {
|
||||
releasable.close();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* State for a grouped {@code VALUES} aggregation. This implementation
|
||||
* emphasizes collect-time performance over the performance of rendering
|
||||
|
@ -146,15 +166,20 @@ class ValuesBytesRefAggregator {
|
|||
* collector operation. But at least it's fairly simple.
|
||||
*/
|
||||
public static class GroupingState implements GroupingAggregatorState {
|
||||
final LongLongHash values;
|
||||
private int maxGroupId = -1;
|
||||
private final BlockFactory blockFactory;
|
||||
private final LongLongHash values;
|
||||
BytesRefHash bytes;
|
||||
|
||||
private GroupingState(BigArrays bigArrays) {
|
||||
private Sorted sortedForOrdinalMerging = null;
|
||||
|
||||
private GroupingState(DriverContext driverContext) {
|
||||
this.blockFactory = driverContext.blockFactory();
|
||||
LongLongHash _values = null;
|
||||
BytesRefHash _bytes = null;
|
||||
try {
|
||||
_values = new LongLongHash(1, bigArrays);
|
||||
_bytes = new BytesRefHash(1, bigArrays);
|
||||
_values = new LongLongHash(1, driverContext.bigArrays());
|
||||
_bytes = new BytesRefHash(1, driverContext.bigArrays());
|
||||
|
||||
values = _values;
|
||||
bytes = _bytes;
|
||||
|
@ -171,6 +196,16 @@ class ValuesBytesRefAggregator {
|
|||
blocks[offset] = toBlock(driverContext.blockFactory(), selected);
|
||||
}
|
||||
|
||||
void addValueOrdinal(int groupId, long valueOrdinal) {
|
||||
values.add(groupId, valueOrdinal);
|
||||
maxGroupId = Math.max(maxGroupId, groupId);
|
||||
}
|
||||
|
||||
void addValue(int groupId, BytesRef v) {
|
||||
values.add(groupId, BlockHash.hashOrdToGroup(bytes.add(v)));
|
||||
maxGroupId = Math.max(maxGroupId, groupId);
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds a {@link Block} with the unique values collected for the {@code #selected}
|
||||
* groups. This is the implementation of the final and intermediate results of the agg.
|
||||
|
@ -180,8 +215,19 @@ class ValuesBytesRefAggregator {
|
|||
return blockFactory.newConstantNullBlock(selected.getPositionCount());
|
||||
}
|
||||
|
||||
try (var sorted = buildSorted(selected)) {
|
||||
if (OrdinalBytesRefBlock.isDense(selected.getPositionCount(), Math.toIntExact(values.size()))) {
|
||||
return buildOrdinalOutputBlock(blockFactory, selected, sorted.counts, sorted.ids);
|
||||
} else {
|
||||
return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private Sorted buildSorted(IntVector selected) {
|
||||
long selectedCountsSize = 0;
|
||||
long idsSize = 0;
|
||||
Sorted sorted = null;
|
||||
try {
|
||||
/*
|
||||
* Get a count of all groups less than the maximum selected group. Count
|
||||
|
@ -256,16 +302,44 @@ class ValuesBytesRefAggregator {
|
|||
ids[selectedCounts[group]++] = id;
|
||||
}
|
||||
}
|
||||
if (OrdinalBytesRefBlock.isDense(selected.getPositionCount(), Math.toIntExact(values.size()))) {
|
||||
return buildOrdinalOutputBlock(blockFactory, selected, selectedCounts, ids);
|
||||
} else {
|
||||
return buildOutputBlock(blockFactory, selected, selectedCounts, ids);
|
||||
}
|
||||
final long totalMemoryUsed = selectedCountsSize + idsSize;
|
||||
sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids);
|
||||
return sorted;
|
||||
} finally {
|
||||
blockFactory.adjustBreaker(-selectedCountsSize - idsSize);
|
||||
if (sorted == null) {
|
||||
blockFactory.adjustBreaker(-selectedCountsSize - idsSize);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private Sorted sortedForOrdinalMerging(GroupingState other) {
|
||||
if (sortedForOrdinalMerging == null) {
|
||||
try (var selected = IntVector.range(0, maxGroupId + 1, blockFactory)) {
|
||||
sortedForOrdinalMerging = buildSorted(selected);
|
||||
// hash all the bytes to the destination to avoid hashing them multiple times
|
||||
BytesRef scratch = new BytesRef();
|
||||
final int totalValue = Math.toIntExact(bytes.size());
|
||||
blockFactory.adjustBreaker((long) totalValue * Integer.BYTES);
|
||||
try {
|
||||
final int[] mappedIds = new int[totalValue];
|
||||
for (int i = 0; i < totalValue; i++) {
|
||||
var v = bytes.get(i, scratch);
|
||||
mappedIds[i] = Math.toIntExact(BlockHash.hashOrdToGroup(other.bytes.add(v)));
|
||||
}
|
||||
// no longer need the bytes
|
||||
bytes.close();
|
||||
bytes = null;
|
||||
for (int i = 0; i < sortedForOrdinalMerging.ids.length; i++) {
|
||||
sortedForOrdinalMerging.ids[i] = mappedIds[Math.toIntExact(values.getKey2(sortedForOrdinalMerging.ids[i]))];
|
||||
}
|
||||
} finally {
|
||||
blockFactory.adjustBreaker(-(long) totalValue * Integer.BYTES);
|
||||
}
|
||||
}
|
||||
}
|
||||
return sortedForOrdinalMerging;
|
||||
}
|
||||
|
||||
Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) {
|
||||
/*
|
||||
* Insert the ids in order.
|
||||
|
@ -279,11 +353,11 @@ class ValuesBytesRefAggregator {
|
|||
int count = end - start;
|
||||
switch (count) {
|
||||
case 0 -> builder.appendNull();
|
||||
case 1 -> append(builder, ids[start], scratch);
|
||||
case 1 -> builder.appendBytesRef(getValue(ids[start], scratch));
|
||||
default -> {
|
||||
builder.beginPositionEntry();
|
||||
for (int i = start; i < end; i++) {
|
||||
append(builder, ids[i], scratch);
|
||||
builder.appendBytesRef(getValue(ids[i], scratch));
|
||||
}
|
||||
builder.endPositionEntry();
|
||||
}
|
||||
|
@ -331,9 +405,8 @@ class ValuesBytesRefAggregator {
|
|||
}
|
||||
}
|
||||
|
||||
private void append(BytesRefBlock.Builder builder, int id, BytesRef scratch) {
|
||||
BytesRef value = bytes.get(values.getKey2(id), scratch);
|
||||
builder.appendBytesRef(value);
|
||||
BytesRef getValue(int valueId, BytesRef scratch) {
|
||||
return bytes.get(values.getKey2(valueId), scratch);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -343,7 +416,7 @@ class ValuesBytesRefAggregator {
|
|||
|
||||
@Override
|
||||
public void close() {
|
||||
Releasables.closeExpectNoException(values, bytes);
|
||||
Releasables.closeExpectNoException(values, bytes, sortedForOrdinalMerging);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,6 +19,8 @@ import org.elasticsearch.compute.data.BlockFactory;
|
|||
import org.elasticsearch.compute.data.DoubleBlock;
|
||||
import org.elasticsearch.compute.data.IntVector;
|
||||
import org.elasticsearch.compute.operator.DriverContext;
|
||||
import org.elasticsearch.core.Releasable;
|
||||
import org.elasticsearch.core.Releasables;
|
||||
|
||||
/**
|
||||
* Aggregates field values for double.
|
||||
|
@ -48,28 +50,32 @@ class ValuesDoubleAggregator {
|
|||
return state.toBlock(driverContext.blockFactory());
|
||||
}
|
||||
|
||||
public static GroupingState initGrouping(BigArrays bigArrays) {
|
||||
return new GroupingState(bigArrays);
|
||||
public static GroupingState initGrouping(DriverContext driverContext) {
|
||||
return new GroupingState(driverContext);
|
||||
}
|
||||
|
||||
public static void combine(GroupingState state, int groupId, double v) {
|
||||
state.values.add(groupId, Double.doubleToLongBits(v));
|
||||
state.addValue(groupId, v);
|
||||
}
|
||||
|
||||
public static void combineIntermediate(GroupingState state, int groupId, DoubleBlock values, int valuesPosition) {
|
||||
int start = values.getFirstValueIndex(valuesPosition);
|
||||
int end = start + values.getValueCount(valuesPosition);
|
||||
for (int i = start; i < end; i++) {
|
||||
combine(state, groupId, values.getDouble(i));
|
||||
state.addValue(groupId, values.getDouble(i));
|
||||
}
|
||||
}
|
||||
|
||||
public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) {
|
||||
for (int id = 0; id < state.values.size(); id++) {
|
||||
if (state.values.getKey1(id) == statePosition) {
|
||||
double value = Double.longBitsToDouble(state.values.getKey2(id));
|
||||
combine(current, currentGroupId, value);
|
||||
}
|
||||
var sorted = state.sortedForOrdinalMerging(current);
|
||||
if (statePosition > state.maxGroupId) {
|
||||
return;
|
||||
}
|
||||
var start = statePosition > 0 ? sorted.counts[statePosition - 1] : 0;
|
||||
var end = sorted.counts[statePosition];
|
||||
for (int i = start; i < end; i++) {
|
||||
int id = sorted.ids[i];
|
||||
current.addValue(currentGroupId, state.getValue(id));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -112,6 +118,20 @@ class ValuesDoubleAggregator {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Values are collected in a hash. Iterating over them in order (row by row) to build the output,
|
||||
* or merging with other state, can be expensive. To optimize this, we build a sorted structure once,
|
||||
* and then use it to iterate over the values in order.
|
||||
*
|
||||
* @param ids positions of the {@link GroupingState#values} to read.
|
||||
*/
|
||||
private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable {
|
||||
@Override
|
||||
public void close() {
|
||||
releasable.close();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* State for a grouped {@code VALUES} aggregation. This implementation
|
||||
* emphasizes collect-time performance over the performance of rendering
|
||||
|
@ -120,10 +140,15 @@ class ValuesDoubleAggregator {
|
|||
* collector operation. But at least it's fairly simple.
|
||||
*/
|
||||
public static class GroupingState implements GroupingAggregatorState {
|
||||
private int maxGroupId = -1;
|
||||
private final BlockFactory blockFactory;
|
||||
private final LongLongHash values;
|
||||
|
||||
private GroupingState(BigArrays bigArrays) {
|
||||
values = new LongLongHash(1, bigArrays);
|
||||
private Sorted sortedForOrdinalMerging = null;
|
||||
|
||||
private GroupingState(DriverContext driverContext) {
|
||||
this.blockFactory = driverContext.blockFactory();
|
||||
values = new LongLongHash(1, driverContext.bigArrays());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -131,6 +156,11 @@ class ValuesDoubleAggregator {
|
|||
blocks[offset] = toBlock(driverContext.blockFactory(), selected);
|
||||
}
|
||||
|
||||
void addValue(int groupId, double v) {
|
||||
values.add(groupId, Double.doubleToLongBits(v));
|
||||
maxGroupId = Math.max(maxGroupId, groupId);
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds a {@link Block} with the unique values collected for the {@code #selected}
|
||||
* groups. This is the implementation of the final and intermediate results of the agg.
|
||||
|
@ -140,8 +170,15 @@ class ValuesDoubleAggregator {
|
|||
return blockFactory.newConstantNullBlock(selected.getPositionCount());
|
||||
}
|
||||
|
||||
try (var sorted = buildSorted(selected)) {
|
||||
return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids);
|
||||
}
|
||||
}
|
||||
|
||||
private Sorted buildSorted(IntVector selected) {
|
||||
long selectedCountsSize = 0;
|
||||
long idsSize = 0;
|
||||
Sorted sorted = null;
|
||||
try {
|
||||
/*
|
||||
* Get a count of all groups less than the maximum selected group. Count
|
||||
|
@ -216,12 +253,25 @@ class ValuesDoubleAggregator {
|
|||
ids[selectedCounts[group]++] = id;
|
||||
}
|
||||
}
|
||||
return buildOutputBlock(blockFactory, selected, selectedCounts, ids);
|
||||
final long totalMemoryUsed = selectedCountsSize + idsSize;
|
||||
sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids);
|
||||
return sorted;
|
||||
} finally {
|
||||
blockFactory.adjustBreaker(-selectedCountsSize - idsSize);
|
||||
if (sorted == null) {
|
||||
blockFactory.adjustBreaker(-selectedCountsSize - idsSize);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private Sorted sortedForOrdinalMerging(GroupingState other) {
|
||||
if (sortedForOrdinalMerging == null) {
|
||||
try (var selected = IntVector.range(0, maxGroupId + 1, blockFactory)) {
|
||||
sortedForOrdinalMerging = buildSorted(selected);
|
||||
}
|
||||
}
|
||||
return sortedForOrdinalMerging;
|
||||
}
|
||||
|
||||
Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) {
|
||||
/*
|
||||
* Insert the ids in order.
|
||||
|
@ -234,11 +284,11 @@ class ValuesDoubleAggregator {
|
|||
int count = end - start;
|
||||
switch (count) {
|
||||
case 0 -> builder.appendNull();
|
||||
case 1 -> append(builder, ids[start]);
|
||||
case 1 -> builder.appendDouble(getValue(ids[start]));
|
||||
default -> {
|
||||
builder.beginPositionEntry();
|
||||
for (int i = start; i < end; i++) {
|
||||
append(builder, ids[i]);
|
||||
builder.appendDouble(getValue(ids[i]));
|
||||
}
|
||||
builder.endPositionEntry();
|
||||
}
|
||||
|
@ -249,9 +299,8 @@ class ValuesDoubleAggregator {
|
|||
}
|
||||
}
|
||||
|
||||
private void append(DoubleBlock.Builder builder, int id) {
|
||||
double value = Double.longBitsToDouble(values.getKey2(id));
|
||||
builder.appendDouble(value);
|
||||
double getValue(int valueId) {
|
||||
return Double.longBitsToDouble(values.getKey2(valueId));
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -261,7 +310,7 @@ class ValuesDoubleAggregator {
|
|||
|
||||
@Override
|
||||
public void close() {
|
||||
values.close();
|
||||
Releasables.closeExpectNoException(values, sortedForOrdinalMerging);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,6 +18,8 @@ import org.elasticsearch.compute.data.BlockFactory;
|
|||
import org.elasticsearch.compute.data.FloatBlock;
|
||||
import org.elasticsearch.compute.data.IntVector;
|
||||
import org.elasticsearch.compute.operator.DriverContext;
|
||||
import org.elasticsearch.core.Releasable;
|
||||
import org.elasticsearch.core.Releasables;
|
||||
|
||||
/**
|
||||
* Aggregates field values for float.
|
||||
|
@ -47,34 +49,32 @@ class ValuesFloatAggregator {
|
|||
return state.toBlock(driverContext.blockFactory());
|
||||
}
|
||||
|
||||
public static GroupingState initGrouping(BigArrays bigArrays) {
|
||||
return new GroupingState(bigArrays);
|
||||
public static GroupingState initGrouping(DriverContext driverContext) {
|
||||
return new GroupingState(driverContext);
|
||||
}
|
||||
|
||||
public static void combine(GroupingState state, int groupId, float v) {
|
||||
/*
|
||||
* Encode the groupId and value into a single long -
|
||||
* the top 32 bits for the group, the bottom 32 for the value.
|
||||
*/
|
||||
state.values.add((((long) groupId) << Float.SIZE) | (Float.floatToIntBits(v) & 0xFFFFFFFFL));
|
||||
state.addValue(groupId, v);
|
||||
}
|
||||
|
||||
public static void combineIntermediate(GroupingState state, int groupId, FloatBlock values, int valuesPosition) {
|
||||
int start = values.getFirstValueIndex(valuesPosition);
|
||||
int end = start + values.getValueCount(valuesPosition);
|
||||
for (int i = start; i < end; i++) {
|
||||
combine(state, groupId, values.getFloat(i));
|
||||
state.addValue(groupId, values.getFloat(i));
|
||||
}
|
||||
}
|
||||
|
||||
public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) {
|
||||
for (int id = 0; id < state.values.size(); id++) {
|
||||
long both = state.values.get(id);
|
||||
int group = (int) (both >>> Float.SIZE);
|
||||
if (group == statePosition) {
|
||||
float value = Float.intBitsToFloat((int) both);
|
||||
combine(current, currentGroupId, value);
|
||||
}
|
||||
var sorted = state.sortedForOrdinalMerging(current);
|
||||
if (statePosition > state.maxGroupId) {
|
||||
return;
|
||||
}
|
||||
var start = statePosition > 0 ? sorted.counts[statePosition - 1] : 0;
|
||||
var end = sorted.counts[statePosition];
|
||||
for (int i = start; i < end; i++) {
|
||||
int id = sorted.ids[i];
|
||||
current.addValue(currentGroupId, state.getValue(id));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -117,6 +117,20 @@ class ValuesFloatAggregator {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Values are collected in a hash. Iterating over them in order (row by row) to build the output,
|
||||
* or merging with other state, can be expensive. To optimize this, we build a sorted structure once,
|
||||
* and then use it to iterate over the values in order.
|
||||
*
|
||||
* @param ids positions of the {@link GroupingState#values} to read.
|
||||
*/
|
||||
private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable {
|
||||
@Override
|
||||
public void close() {
|
||||
releasable.close();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* State for a grouped {@code VALUES} aggregation. This implementation
|
||||
* emphasizes collect-time performance over the performance of rendering
|
||||
|
@ -125,10 +139,15 @@ class ValuesFloatAggregator {
|
|||
* collector operation. But at least it's fairly simple.
|
||||
*/
|
||||
public static class GroupingState implements GroupingAggregatorState {
|
||||
private int maxGroupId = -1;
|
||||
private final BlockFactory blockFactory;
|
||||
private final LongHash values;
|
||||
|
||||
private GroupingState(BigArrays bigArrays) {
|
||||
values = new LongHash(1, bigArrays);
|
||||
private Sorted sortedForOrdinalMerging = null;
|
||||
|
||||
private GroupingState(DriverContext driverContext) {
|
||||
this.blockFactory = driverContext.blockFactory();
|
||||
values = new LongHash(1, driverContext.bigArrays());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -136,6 +155,15 @@ class ValuesFloatAggregator {
|
|||
blocks[offset] = toBlock(driverContext.blockFactory(), selected);
|
||||
}
|
||||
|
||||
void addValue(int groupId, float v) {
|
||||
/*
|
||||
* Encode the groupId and value into a single long -
|
||||
* the top 32 bits for the group, the bottom 32 for the value.
|
||||
*/
|
||||
values.add((((long) groupId) << Float.SIZE) | (Float.floatToIntBits(v) & 0xFFFFFFFFL));
|
||||
maxGroupId = Math.max(maxGroupId, groupId);
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds a {@link Block} with the unique values collected for the {@code #selected}
|
||||
* groups. This is the implementation of the final and intermediate results of the agg.
|
||||
|
@ -145,8 +173,15 @@ class ValuesFloatAggregator {
|
|||
return blockFactory.newConstantNullBlock(selected.getPositionCount());
|
||||
}
|
||||
|
||||
try (var sorted = buildSorted(selected)) {
|
||||
return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids);
|
||||
}
|
||||
}
|
||||
|
||||
private Sorted buildSorted(IntVector selected) {
|
||||
long selectedCountsSize = 0;
|
||||
long idsSize = 0;
|
||||
Sorted sorted = null;
|
||||
try {
|
||||
/*
|
||||
* Get a count of all groups less than the maximum selected group. Count
|
||||
|
@ -223,12 +258,25 @@ class ValuesFloatAggregator {
|
|||
ids[selectedCounts[group]++] = id;
|
||||
}
|
||||
}
|
||||
return buildOutputBlock(blockFactory, selected, selectedCounts, ids);
|
||||
final long totalMemoryUsed = selectedCountsSize + idsSize;
|
||||
sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids);
|
||||
return sorted;
|
||||
} finally {
|
||||
blockFactory.adjustBreaker(-selectedCountsSize - idsSize);
|
||||
if (sorted == null) {
|
||||
blockFactory.adjustBreaker(-selectedCountsSize - idsSize);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private Sorted sortedForOrdinalMerging(GroupingState other) {
|
||||
if (sortedForOrdinalMerging == null) {
|
||||
try (var selected = IntVector.range(0, maxGroupId + 1, blockFactory)) {
|
||||
sortedForOrdinalMerging = buildSorted(selected);
|
||||
}
|
||||
}
|
||||
return sortedForOrdinalMerging;
|
||||
}
|
||||
|
||||
Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) {
|
||||
/*
|
||||
* Insert the ids in order.
|
||||
|
@ -241,11 +289,11 @@ class ValuesFloatAggregator {
|
|||
int count = end - start;
|
||||
switch (count) {
|
||||
case 0 -> builder.appendNull();
|
||||
case 1 -> append(builder, ids[start]);
|
||||
case 1 -> builder.appendFloat(getValue(ids[start]));
|
||||
default -> {
|
||||
builder.beginPositionEntry();
|
||||
for (int i = start; i < end; i++) {
|
||||
append(builder, ids[i]);
|
||||
builder.appendFloat(getValue(ids[i]));
|
||||
}
|
||||
builder.endPositionEntry();
|
||||
}
|
||||
|
@ -256,10 +304,9 @@ class ValuesFloatAggregator {
|
|||
}
|
||||
}
|
||||
|
||||
private void append(FloatBlock.Builder builder, int id) {
|
||||
long both = values.get(id);
|
||||
float value = Float.intBitsToFloat((int) both);
|
||||
builder.appendFloat(value);
|
||||
float getValue(int valueId) {
|
||||
long both = values.get(valueId);
|
||||
return Float.intBitsToFloat((int) both);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -269,7 +316,7 @@ class ValuesFloatAggregator {
|
|||
|
||||
@Override
|
||||
public void close() {
|
||||
values.close();
|
||||
Releasables.closeExpectNoException(values, sortedForOrdinalMerging);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,6 +18,8 @@ import org.elasticsearch.compute.data.BlockFactory;
|
|||
import org.elasticsearch.compute.data.IntBlock;
|
||||
import org.elasticsearch.compute.data.IntVector;
|
||||
import org.elasticsearch.compute.operator.DriverContext;
|
||||
import org.elasticsearch.core.Releasable;
|
||||
import org.elasticsearch.core.Releasables;
|
||||
|
||||
/**
|
||||
* Aggregates field values for int.
|
||||
|
@ -47,34 +49,32 @@ class ValuesIntAggregator {
|
|||
return state.toBlock(driverContext.blockFactory());
|
||||
}
|
||||
|
||||
public static GroupingState initGrouping(BigArrays bigArrays) {
|
||||
return new GroupingState(bigArrays);
|
||||
public static GroupingState initGrouping(DriverContext driverContext) {
|
||||
return new GroupingState(driverContext);
|
||||
}
|
||||
|
||||
public static void combine(GroupingState state, int groupId, int v) {
|
||||
/*
|
||||
* Encode the groupId and value into a single long -
|
||||
* the top 32 bits for the group, the bottom 32 for the value.
|
||||
*/
|
||||
state.values.add((((long) groupId) << Integer.SIZE) | (v & 0xFFFFFFFFL));
|
||||
state.addValue(groupId, v);
|
||||
}
|
||||
|
||||
public static void combineIntermediate(GroupingState state, int groupId, IntBlock values, int valuesPosition) {
|
||||
int start = values.getFirstValueIndex(valuesPosition);
|
||||
int end = start + values.getValueCount(valuesPosition);
|
||||
for (int i = start; i < end; i++) {
|
||||
combine(state, groupId, values.getInt(i));
|
||||
state.addValue(groupId, values.getInt(i));
|
||||
}
|
||||
}
|
||||
|
||||
public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) {
|
||||
for (int id = 0; id < state.values.size(); id++) {
|
||||
long both = state.values.get(id);
|
||||
int group = (int) (both >>> Integer.SIZE);
|
||||
if (group == statePosition) {
|
||||
int value = (int) both;
|
||||
combine(current, currentGroupId, value);
|
||||
}
|
||||
var sorted = state.sortedForOrdinalMerging(current);
|
||||
if (statePosition > state.maxGroupId) {
|
||||
return;
|
||||
}
|
||||
var start = statePosition > 0 ? sorted.counts[statePosition - 1] : 0;
|
||||
var end = sorted.counts[statePosition];
|
||||
for (int i = start; i < end; i++) {
|
||||
int id = sorted.ids[i];
|
||||
current.addValue(currentGroupId, state.getValue(id));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -117,6 +117,20 @@ class ValuesIntAggregator {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Values are collected in a hash. Iterating over them in order (row by row) to build the output,
|
||||
* or merging with other state, can be expensive. To optimize this, we build a sorted structure once,
|
||||
* and then use it to iterate over the values in order.
|
||||
*
|
||||
* @param ids positions of the {@link GroupingState#values} to read.
|
||||
*/
|
||||
private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable {
|
||||
@Override
|
||||
public void close() {
|
||||
releasable.close();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* State for a grouped {@code VALUES} aggregation. This implementation
|
||||
* emphasizes collect-time performance over the performance of rendering
|
||||
|
@ -125,10 +139,15 @@ class ValuesIntAggregator {
|
|||
* collector operation. But at least it's fairly simple.
|
||||
*/
|
||||
public static class GroupingState implements GroupingAggregatorState {
|
||||
private int maxGroupId = -1;
|
||||
private final BlockFactory blockFactory;
|
||||
private final LongHash values;
|
||||
|
||||
private GroupingState(BigArrays bigArrays) {
|
||||
values = new LongHash(1, bigArrays);
|
||||
private Sorted sortedForOrdinalMerging = null;
|
||||
|
||||
private GroupingState(DriverContext driverContext) {
|
||||
this.blockFactory = driverContext.blockFactory();
|
||||
values = new LongHash(1, driverContext.bigArrays());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -136,6 +155,15 @@ class ValuesIntAggregator {
|
|||
blocks[offset] = toBlock(driverContext.blockFactory(), selected);
|
||||
}
|
||||
|
||||
void addValue(int groupId, int v) {
|
||||
/*
|
||||
* Encode the groupId and value into a single long -
|
||||
* the top 32 bits for the group, the bottom 32 for the value.
|
||||
*/
|
||||
values.add((((long) groupId) << Integer.SIZE) | (v & 0xFFFFFFFFL));
|
||||
maxGroupId = Math.max(maxGroupId, groupId);
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds a {@link Block} with the unique values collected for the {@code #selected}
|
||||
* groups. This is the implementation of the final and intermediate results of the agg.
|
||||
|
@ -145,8 +173,15 @@ class ValuesIntAggregator {
|
|||
return blockFactory.newConstantNullBlock(selected.getPositionCount());
|
||||
}
|
||||
|
||||
try (var sorted = buildSorted(selected)) {
|
||||
return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids);
|
||||
}
|
||||
}
|
||||
|
||||
private Sorted buildSorted(IntVector selected) {
|
||||
long selectedCountsSize = 0;
|
||||
long idsSize = 0;
|
||||
Sorted sorted = null;
|
||||
try {
|
||||
/*
|
||||
* Get a count of all groups less than the maximum selected group. Count
|
||||
|
@ -223,12 +258,25 @@ class ValuesIntAggregator {
|
|||
ids[selectedCounts[group]++] = id;
|
||||
}
|
||||
}
|
||||
return buildOutputBlock(blockFactory, selected, selectedCounts, ids);
|
||||
final long totalMemoryUsed = selectedCountsSize + idsSize;
|
||||
sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids);
|
||||
return sorted;
|
||||
} finally {
|
||||
blockFactory.adjustBreaker(-selectedCountsSize - idsSize);
|
||||
if (sorted == null) {
|
||||
blockFactory.adjustBreaker(-selectedCountsSize - idsSize);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private Sorted sortedForOrdinalMerging(GroupingState other) {
|
||||
if (sortedForOrdinalMerging == null) {
|
||||
try (var selected = IntVector.range(0, maxGroupId + 1, blockFactory)) {
|
||||
sortedForOrdinalMerging = buildSorted(selected);
|
||||
}
|
||||
}
|
||||
return sortedForOrdinalMerging;
|
||||
}
|
||||
|
||||
Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) {
|
||||
/*
|
||||
* Insert the ids in order.
|
||||
|
@ -241,11 +289,11 @@ class ValuesIntAggregator {
|
|||
int count = end - start;
|
||||
switch (count) {
|
||||
case 0 -> builder.appendNull();
|
||||
case 1 -> append(builder, ids[start]);
|
||||
case 1 -> builder.appendInt(getValue(ids[start]));
|
||||
default -> {
|
||||
builder.beginPositionEntry();
|
||||
for (int i = start; i < end; i++) {
|
||||
append(builder, ids[i]);
|
||||
builder.appendInt(getValue(ids[i]));
|
||||
}
|
||||
builder.endPositionEntry();
|
||||
}
|
||||
|
@ -256,10 +304,9 @@ class ValuesIntAggregator {
|
|||
}
|
||||
}
|
||||
|
||||
private void append(IntBlock.Builder builder, int id) {
|
||||
long both = values.get(id);
|
||||
int value = (int) both;
|
||||
builder.appendInt(value);
|
||||
int getValue(int valueId) {
|
||||
long both = values.get(valueId);
|
||||
return (int) both;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -269,7 +316,7 @@ class ValuesIntAggregator {
|
|||
|
||||
@Override
|
||||
public void close() {
|
||||
values.close();
|
||||
Releasables.closeExpectNoException(values, sortedForOrdinalMerging);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,6 +19,8 @@ import org.elasticsearch.compute.data.BlockFactory;
|
|||
import org.elasticsearch.compute.data.IntVector;
|
||||
import org.elasticsearch.compute.data.LongBlock;
|
||||
import org.elasticsearch.compute.operator.DriverContext;
|
||||
import org.elasticsearch.core.Releasable;
|
||||
import org.elasticsearch.core.Releasables;
|
||||
|
||||
/**
|
||||
* Aggregates field values for long.
|
||||
|
@ -48,28 +50,32 @@ class ValuesLongAggregator {
|
|||
return state.toBlock(driverContext.blockFactory());
|
||||
}
|
||||
|
||||
public static GroupingState initGrouping(BigArrays bigArrays) {
|
||||
return new GroupingState(bigArrays);
|
||||
public static GroupingState initGrouping(DriverContext driverContext) {
|
||||
return new GroupingState(driverContext);
|
||||
}
|
||||
|
||||
public static void combine(GroupingState state, int groupId, long v) {
|
||||
state.values.add(groupId, v);
|
||||
state.addValue(groupId, v);
|
||||
}
|
||||
|
||||
public static void combineIntermediate(GroupingState state, int groupId, LongBlock values, int valuesPosition) {
|
||||
int start = values.getFirstValueIndex(valuesPosition);
|
||||
int end = start + values.getValueCount(valuesPosition);
|
||||
for (int i = start; i < end; i++) {
|
||||
combine(state, groupId, values.getLong(i));
|
||||
state.addValue(groupId, values.getLong(i));
|
||||
}
|
||||
}
|
||||
|
||||
public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) {
|
||||
for (int id = 0; id < state.values.size(); id++) {
|
||||
if (state.values.getKey1(id) == statePosition) {
|
||||
long value = state.values.getKey2(id);
|
||||
combine(current, currentGroupId, value);
|
||||
}
|
||||
var sorted = state.sortedForOrdinalMerging(current);
|
||||
if (statePosition > state.maxGroupId) {
|
||||
return;
|
||||
}
|
||||
var start = statePosition > 0 ? sorted.counts[statePosition - 1] : 0;
|
||||
var end = sorted.counts[statePosition];
|
||||
for (int i = start; i < end; i++) {
|
||||
int id = sorted.ids[i];
|
||||
current.addValue(currentGroupId, state.getValue(id));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -112,6 +118,20 @@ class ValuesLongAggregator {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Values are collected in a hash. Iterating over them in order (row by row) to build the output,
|
||||
* or merging with other state, can be expensive. To optimize this, we build a sorted structure once,
|
||||
* and then use it to iterate over the values in order.
|
||||
*
|
||||
* @param ids positions of the {@link GroupingState#values} to read.
|
||||
*/
|
||||
private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable {
|
||||
@Override
|
||||
public void close() {
|
||||
releasable.close();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* State for a grouped {@code VALUES} aggregation. This implementation
|
||||
* emphasizes collect-time performance over the performance of rendering
|
||||
|
@ -120,10 +140,15 @@ class ValuesLongAggregator {
|
|||
* collector operation. But at least it's fairly simple.
|
||||
*/
|
||||
public static class GroupingState implements GroupingAggregatorState {
|
||||
private int maxGroupId = -1;
|
||||
private final BlockFactory blockFactory;
|
||||
private final LongLongHash values;
|
||||
|
||||
private GroupingState(BigArrays bigArrays) {
|
||||
values = new LongLongHash(1, bigArrays);
|
||||
private Sorted sortedForOrdinalMerging = null;
|
||||
|
||||
private GroupingState(DriverContext driverContext) {
|
||||
this.blockFactory = driverContext.blockFactory();
|
||||
values = new LongLongHash(1, driverContext.bigArrays());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -131,6 +156,11 @@ class ValuesLongAggregator {
|
|||
blocks[offset] = toBlock(driverContext.blockFactory(), selected);
|
||||
}
|
||||
|
||||
void addValue(int groupId, long v) {
|
||||
values.add(groupId, v);
|
||||
maxGroupId = Math.max(maxGroupId, groupId);
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds a {@link Block} with the unique values collected for the {@code #selected}
|
||||
* groups. This is the implementation of the final and intermediate results of the agg.
|
||||
|
@ -140,8 +170,15 @@ class ValuesLongAggregator {
|
|||
return blockFactory.newConstantNullBlock(selected.getPositionCount());
|
||||
}
|
||||
|
||||
try (var sorted = buildSorted(selected)) {
|
||||
return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids);
|
||||
}
|
||||
}
|
||||
|
||||
private Sorted buildSorted(IntVector selected) {
|
||||
long selectedCountsSize = 0;
|
||||
long idsSize = 0;
|
||||
Sorted sorted = null;
|
||||
try {
|
||||
/*
|
||||
* Get a count of all groups less than the maximum selected group. Count
|
||||
|
@ -216,12 +253,25 @@ class ValuesLongAggregator {
|
|||
ids[selectedCounts[group]++] = id;
|
||||
}
|
||||
}
|
||||
return buildOutputBlock(blockFactory, selected, selectedCounts, ids);
|
||||
final long totalMemoryUsed = selectedCountsSize + idsSize;
|
||||
sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids);
|
||||
return sorted;
|
||||
} finally {
|
||||
blockFactory.adjustBreaker(-selectedCountsSize - idsSize);
|
||||
if (sorted == null) {
|
||||
blockFactory.adjustBreaker(-selectedCountsSize - idsSize);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private Sorted sortedForOrdinalMerging(GroupingState other) {
|
||||
if (sortedForOrdinalMerging == null) {
|
||||
try (var selected = IntVector.range(0, maxGroupId + 1, blockFactory)) {
|
||||
sortedForOrdinalMerging = buildSorted(selected);
|
||||
}
|
||||
}
|
||||
return sortedForOrdinalMerging;
|
||||
}
|
||||
|
||||
Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) {
|
||||
/*
|
||||
* Insert the ids in order.
|
||||
|
@ -234,11 +284,11 @@ class ValuesLongAggregator {
|
|||
int count = end - start;
|
||||
switch (count) {
|
||||
case 0 -> builder.appendNull();
|
||||
case 1 -> append(builder, ids[start]);
|
||||
case 1 -> builder.appendLong(getValue(ids[start]));
|
||||
default -> {
|
||||
builder.beginPositionEntry();
|
||||
for (int i = start; i < end; i++) {
|
||||
append(builder, ids[i]);
|
||||
builder.appendLong(getValue(ids[i]));
|
||||
}
|
||||
builder.endPositionEntry();
|
||||
}
|
||||
|
@ -249,9 +299,8 @@ class ValuesLongAggregator {
|
|||
}
|
||||
}
|
||||
|
||||
private void append(LongBlock.Builder builder, int id) {
|
||||
long value = values.getKey2(id);
|
||||
builder.appendLong(value);
|
||||
long getValue(int valueId) {
|
||||
return values.getKey2(valueId);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -261,7 +310,7 @@ class ValuesLongAggregator {
|
|||
|
||||
@Override
|
||||
public void close() {
|
||||
values.close();
|
||||
Releasables.closeExpectNoException(values, sortedForOrdinalMerging);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -43,7 +43,7 @@ public final class ValuesBytesRefGroupingAggregatorFunction implements GroupingA
|
|||
|
||||
public static ValuesBytesRefGroupingAggregatorFunction create(List<Integer> channels,
|
||||
DriverContext driverContext) {
|
||||
return new ValuesBytesRefGroupingAggregatorFunction(channels, ValuesBytesRefAggregator.initGrouping(driverContext.bigArrays()), driverContext);
|
||||
return new ValuesBytesRefGroupingAggregatorFunction(channels, ValuesBytesRefAggregator.initGrouping(driverContext), driverContext);
|
||||
}
|
||||
|
||||
public static List<IntermediateStateDesc> intermediateStateDesc() {
|
||||
|
|
|
@ -42,7 +42,7 @@ public final class ValuesDoubleGroupingAggregatorFunction implements GroupingAgg
|
|||
|
||||
public static ValuesDoubleGroupingAggregatorFunction create(List<Integer> channels,
|
||||
DriverContext driverContext) {
|
||||
return new ValuesDoubleGroupingAggregatorFunction(channels, ValuesDoubleAggregator.initGrouping(driverContext.bigArrays()), driverContext);
|
||||
return new ValuesDoubleGroupingAggregatorFunction(channels, ValuesDoubleAggregator.initGrouping(driverContext), driverContext);
|
||||
}
|
||||
|
||||
public static List<IntermediateStateDesc> intermediateStateDesc() {
|
||||
|
|
|
@ -42,7 +42,7 @@ public final class ValuesFloatGroupingAggregatorFunction implements GroupingAggr
|
|||
|
||||
public static ValuesFloatGroupingAggregatorFunction create(List<Integer> channels,
|
||||
DriverContext driverContext) {
|
||||
return new ValuesFloatGroupingAggregatorFunction(channels, ValuesFloatAggregator.initGrouping(driverContext.bigArrays()), driverContext);
|
||||
return new ValuesFloatGroupingAggregatorFunction(channels, ValuesFloatAggregator.initGrouping(driverContext), driverContext);
|
||||
}
|
||||
|
||||
public static List<IntermediateStateDesc> intermediateStateDesc() {
|
||||
|
|
|
@ -41,7 +41,7 @@ public final class ValuesIntGroupingAggregatorFunction implements GroupingAggreg
|
|||
|
||||
public static ValuesIntGroupingAggregatorFunction create(List<Integer> channels,
|
||||
DriverContext driverContext) {
|
||||
return new ValuesIntGroupingAggregatorFunction(channels, ValuesIntAggregator.initGrouping(driverContext.bigArrays()), driverContext);
|
||||
return new ValuesIntGroupingAggregatorFunction(channels, ValuesIntAggregator.initGrouping(driverContext), driverContext);
|
||||
}
|
||||
|
||||
public static List<IntermediateStateDesc> intermediateStateDesc() {
|
||||
|
|
|
@ -42,7 +42,7 @@ public final class ValuesLongGroupingAggregatorFunction implements GroupingAggre
|
|||
|
||||
public static ValuesLongGroupingAggregatorFunction create(List<Integer> channels,
|
||||
DriverContext driverContext) {
|
||||
return new ValuesLongGroupingAggregatorFunction(channels, ValuesLongAggregator.initGrouping(driverContext.bigArrays()), driverContext);
|
||||
return new ValuesLongGroupingAggregatorFunction(channels, ValuesLongAggregator.initGrouping(driverContext), driverContext);
|
||||
}
|
||||
|
||||
public static List<IntermediateStateDesc> intermediateStateDesc() {
|
||||
|
|
|
@ -55,7 +55,7 @@ final class ValuesBytesRefAggregators {
|
|||
int valuesStart = ordinalIds.getFirstValueIndex(groupPosition + positionOffset);
|
||||
int valuesEnd = valuesStart + ordinalIds.getValueCount(groupPosition + positionOffset);
|
||||
for (int v = valuesStart; v < valuesEnd; v++) {
|
||||
state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(v)));
|
||||
state.addValueOrdinal(groupId, hashIds.getInt(ordinalIds.getInt(v)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -77,7 +77,7 @@ final class ValuesBytesRefAggregators {
|
|||
int valuesStart = ordinalIds.getFirstValueIndex(groupPosition + positionOffset);
|
||||
int valuesEnd = valuesStart + ordinalIds.getValueCount(groupPosition + positionOffset);
|
||||
for (int v = valuesStart; v < valuesEnd; v++) {
|
||||
state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(v)));
|
||||
state.addValueOrdinal(groupId, hashIds.getInt(ordinalIds.getInt(v)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -93,7 +93,7 @@ final class ValuesBytesRefAggregators {
|
|||
int valuesStart = ordinalIds.getFirstValueIndex(groupPosition + positionOffset);
|
||||
int valuesEnd = valuesStart + ordinalIds.getValueCount(groupPosition + positionOffset);
|
||||
for (int v = valuesStart; v < valuesEnd; v++) {
|
||||
state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(v)));
|
||||
state.addValueOrdinal(groupId, hashIds.getInt(ordinalIds.getInt(v)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -135,7 +135,7 @@ final class ValuesBytesRefAggregators {
|
|||
int groupEnd = groupStart + groupIds.getValueCount(groupPosition);
|
||||
for (int g = groupStart; g < groupEnd; g++) {
|
||||
int groupId = groupIds.getInt(g);
|
||||
state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset)));
|
||||
state.addValueOrdinal(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -150,7 +150,7 @@ final class ValuesBytesRefAggregators {
|
|||
int groupEnd = groupStart + groupIds.getValueCount(groupPosition);
|
||||
for (int g = groupStart; g < groupEnd; g++) {
|
||||
int groupId = groupIds.getInt(g);
|
||||
state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset)));
|
||||
state.addValueOrdinal(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -159,7 +159,7 @@ final class ValuesBytesRefAggregators {
|
|||
public void add(int positionOffset, IntVector groupIds) {
|
||||
for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) {
|
||||
int groupId = groupIds.getInt(groupPosition);
|
||||
state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset)));
|
||||
state.addValueOrdinal(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset)));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -43,12 +43,9 @@ $if(BytesRef)$
|
|||
import org.elasticsearch.compute.data.OrdinalBytesRefBlock;
|
||||
$endif$
|
||||
import org.elasticsearch.compute.operator.DriverContext;
|
||||
$if(BytesRef)$
|
||||
import org.elasticsearch.core.Releasable;
|
||||
import org.elasticsearch.core.Releasables;
|
||||
|
||||
$else$
|
||||
|
||||
$endif$
|
||||
/**
|
||||
* Aggregates field values for $type$.
|
||||
* This class is generated. Edit {@code X-ValuesAggregator.java.st} instead
|
||||
|
@ -90,8 +87,8 @@ $endif$
|
|||
return state.toBlock(driverContext.blockFactory());
|
||||
}
|
||||
|
||||
public static GroupingState initGrouping(BigArrays bigArrays) {
|
||||
return new GroupingState(bigArrays);
|
||||
public static GroupingState initGrouping(DriverContext driverContext) {
|
||||
return new GroupingState(driverContext);
|
||||
}
|
||||
|
||||
$if(BytesRef)$
|
||||
|
@ -113,25 +110,7 @@ $if(BytesRef)$
|
|||
$endif$
|
||||
|
||||
public static void combine(GroupingState state, int groupId, $type$ v) {
|
||||
$if(long)$
|
||||
state.values.add(groupId, v);
|
||||
$elseif(double)$
|
||||
state.values.add(groupId, Double.doubleToLongBits(v));
|
||||
$elseif(BytesRef)$
|
||||
state.values.add(groupId, BlockHash.hashOrdToGroup(state.bytes.add(v)));
|
||||
$elseif(int)$
|
||||
/*
|
||||
* Encode the groupId and value into a single long -
|
||||
* the top 32 bits for the group, the bottom 32 for the value.
|
||||
*/
|
||||
state.values.add((((long) groupId) << Integer.SIZE) | (v & 0xFFFFFFFFL));
|
||||
$elseif(float)$
|
||||
/*
|
||||
* Encode the groupId and value into a single long -
|
||||
* the top 32 bits for the group, the bottom 32 for the value.
|
||||
*/
|
||||
state.values.add((((long) groupId) << Float.SIZE) | (Float.floatToIntBits(v) & 0xFFFFFFFFL));
|
||||
$endif$
|
||||
state.addValue(groupId, v);
|
||||
}
|
||||
|
||||
public static void combineIntermediate(GroupingState state, int groupId, $Type$Block values, int valuesPosition) {
|
||||
|
@ -142,37 +121,27 @@ $endif$
|
|||
int end = start + values.getValueCount(valuesPosition);
|
||||
for (int i = start; i < end; i++) {
|
||||
$if(BytesRef)$
|
||||
combine(state, groupId, values.getBytesRef(i, scratch));
|
||||
state.addValue(groupId, values.getBytesRef(i, scratch));
|
||||
$else$
|
||||
combine(state, groupId, values.get$Type$(i));
|
||||
state.addValue(groupId, values.get$Type$(i));
|
||||
$endif$
|
||||
}
|
||||
}
|
||||
|
||||
public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) {
|
||||
var sorted = state.sortedForOrdinalMerging(current);
|
||||
if (statePosition > state.maxGroupId) {
|
||||
return;
|
||||
}
|
||||
var start = statePosition > 0 ? sorted.counts[statePosition - 1] : 0;
|
||||
var end = sorted.counts[statePosition];
|
||||
for (int i = start; i < end; i++) {
|
||||
int id = sorted.ids[i];
|
||||
$if(BytesRef)$
|
||||
BytesRef scratch = new BytesRef();
|
||||
current.addValueOrdinal(currentGroupId, id);
|
||||
$else$
|
||||
current.addValue(currentGroupId, state.getValue(id));
|
||||
$endif$
|
||||
for (int id = 0; id < state.values.size(); id++) {
|
||||
$if(long||BytesRef)$
|
||||
if (state.values.getKey1(id) == statePosition) {
|
||||
long value = state.values.getKey2(id);
|
||||
$elseif(double)$
|
||||
if (state.values.getKey1(id) == statePosition) {
|
||||
double value = Double.longBitsToDouble(state.values.getKey2(id));
|
||||
$elseif(int)$
|
||||
long both = state.values.get(id);
|
||||
int group = (int) (both >>> Integer.SIZE);
|
||||
if (group == statePosition) {
|
||||
int value = (int) both;
|
||||
$elseif(float)$
|
||||
long both = state.values.get(id);
|
||||
int group = (int) (both >>> Float.SIZE);
|
||||
if (group == statePosition) {
|
||||
float value = Float.intBitsToFloat((int) both);
|
||||
$endif$
|
||||
combine(current, currentGroupId, $if(BytesRef)$state.bytes.get(value, scratch)$else$value$endif$);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -247,6 +216,24 @@ $endif$
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Values are collected in a hash. Iterating over them in order (row by row) to build the output,
|
||||
* or merging with other state, can be expensive. To optimize this, we build a sorted structure once,
|
||||
* and then use it to iterate over the values in order.
|
||||
*
|
||||
* @param ids positions of the {@link GroupingState#values} to read.
|
||||
$if(BytesRef)$
|
||||
* If built from {@link GroupingState#sortedForOrdinalMerging(GroupingState)},
|
||||
* these are ordinals referring to the {@link GroupingState#bytes} in the target state.
|
||||
$endif$
|
||||
*/
|
||||
private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable {
|
||||
@Override
|
||||
public void close() {
|
||||
releasable.close();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* State for a grouped {@code VALUES} aggregation. This implementation
|
||||
* emphasizes collect-time performance over the performance of rendering
|
||||
|
@ -255,26 +242,31 @@ $endif$
|
|||
* collector operation. But at least it's fairly simple.
|
||||
*/
|
||||
public static class GroupingState implements GroupingAggregatorState {
|
||||
private int maxGroupId = -1;
|
||||
private final BlockFactory blockFactory;
|
||||
$if(long||double)$
|
||||
private final LongLongHash values;
|
||||
|
||||
$elseif(BytesRef)$
|
||||
final LongLongHash values;
|
||||
private final LongLongHash values;
|
||||
BytesRefHash bytes;
|
||||
|
||||
$elseif(int||float)$
|
||||
private final LongHash values;
|
||||
|
||||
$endif$
|
||||
private GroupingState(BigArrays bigArrays) {
|
||||
private Sorted sortedForOrdinalMerging = null;
|
||||
|
||||
private GroupingState(DriverContext driverContext) {
|
||||
this.blockFactory = driverContext.blockFactory();
|
||||
$if(long||double)$
|
||||
values = new LongLongHash(1, bigArrays);
|
||||
values = new LongLongHash(1, driverContext.bigArrays());
|
||||
$elseif(BytesRef)$
|
||||
LongLongHash _values = null;
|
||||
BytesRefHash _bytes = null;
|
||||
try {
|
||||
_values = new LongLongHash(1, bigArrays);
|
||||
_bytes = new BytesRefHash(1, bigArrays);
|
||||
_values = new LongLongHash(1, driverContext.bigArrays());
|
||||
_bytes = new BytesRefHash(1, driverContext.bigArrays());
|
||||
|
||||
values = _values;
|
||||
bytes = _bytes;
|
||||
|
@ -285,7 +277,7 @@ $elseif(BytesRef)$
|
|||
Releasables.closeExpectNoException(_values, _bytes);
|
||||
}
|
||||
$elseif(int||float)$
|
||||
values = new LongHash(1, bigArrays);
|
||||
values = new LongHash(1, driverContext.bigArrays());
|
||||
$endif$
|
||||
}
|
||||
|
||||
|
@ -294,6 +286,36 @@ $endif$
|
|||
blocks[offset] = toBlock(driverContext.blockFactory(), selected);
|
||||
}
|
||||
|
||||
$if(BytesRef)$
|
||||
void addValueOrdinal(int groupId, long valueOrdinal) {
|
||||
values.add(groupId, valueOrdinal);
|
||||
maxGroupId = Math.max(maxGroupId, groupId);
|
||||
}
|
||||
|
||||
$endif$
|
||||
void addValue(int groupId, $type$ v) {
|
||||
$if(long)$
|
||||
values.add(groupId, v);
|
||||
$elseif(double)$
|
||||
values.add(groupId, Double.doubleToLongBits(v));
|
||||
$elseif(BytesRef)$
|
||||
values.add(groupId, BlockHash.hashOrdToGroup(bytes.add(v)));
|
||||
$elseif(int)$
|
||||
/*
|
||||
* Encode the groupId and value into a single long -
|
||||
* the top 32 bits for the group, the bottom 32 for the value.
|
||||
*/
|
||||
values.add((((long) groupId) << Integer.SIZE) | (v & 0xFFFFFFFFL));
|
||||
$elseif(float)$
|
||||
/*
|
||||
* Encode the groupId and value into a single long -
|
||||
* the top 32 bits for the group, the bottom 32 for the value.
|
||||
*/
|
||||
values.add((((long) groupId) << Float.SIZE) | (Float.floatToIntBits(v) & 0xFFFFFFFFL));
|
||||
$endif$
|
||||
maxGroupId = Math.max(maxGroupId, groupId);
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds a {@link Block} with the unique values collected for the {@code #selected}
|
||||
* groups. This is the implementation of the final and intermediate results of the agg.
|
||||
|
@ -303,8 +325,23 @@ $endif$
|
|||
return blockFactory.newConstantNullBlock(selected.getPositionCount());
|
||||
}
|
||||
|
||||
try (var sorted = buildSorted(selected)) {
|
||||
$if(BytesRef)$
|
||||
if (OrdinalBytesRefBlock.isDense(selected.getPositionCount(), Math.toIntExact(values.size()))) {
|
||||
return buildOrdinalOutputBlock(blockFactory, selected, sorted.counts, sorted.ids);
|
||||
} else {
|
||||
return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids);
|
||||
}
|
||||
$else$
|
||||
return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids);
|
||||
$endif$
|
||||
}
|
||||
}
|
||||
|
||||
private Sorted buildSorted(IntVector selected) {
|
||||
long selectedCountsSize = 0;
|
||||
long idsSize = 0;
|
||||
Sorted sorted = null;
|
||||
try {
|
||||
/*
|
||||
* Get a count of all groups less than the maximum selected group. Count
|
||||
|
@ -379,30 +416,56 @@ $endif$
|
|||
idsSize = adjust;
|
||||
int[] ids = new int[total];
|
||||
for (int id = 0; id < values.size(); id++) {
|
||||
$if(long||BytesRef||double)$
|
||||
$if(long||BytesRef||double)$
|
||||
int group = (int) values.getKey1(id);
|
||||
$elseif(float||int)$
|
||||
$elseif(float||int)$
|
||||
long both = values.get(id);
|
||||
int group = (int) (both >>> Float.SIZE);
|
||||
$endif$
|
||||
$endif$
|
||||
if (group < selectedCounts.length && selectedCounts[group] >= 0) {
|
||||
ids[selectedCounts[group]++] = id;
|
||||
}
|
||||
}
|
||||
$if(BytesRef)$
|
||||
if (OrdinalBytesRefBlock.isDense(selected.getPositionCount(), Math.toIntExact(values.size()))) {
|
||||
return buildOrdinalOutputBlock(blockFactory, selected, selectedCounts, ids);
|
||||
} else {
|
||||
return buildOutputBlock(blockFactory, selected, selectedCounts, ids);
|
||||
}
|
||||
$else$
|
||||
return buildOutputBlock(blockFactory, selected, selectedCounts, ids);
|
||||
$endif$
|
||||
final long totalMemoryUsed = selectedCountsSize + idsSize;
|
||||
sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids);
|
||||
return sorted;
|
||||
} finally {
|
||||
blockFactory.adjustBreaker(-selectedCountsSize - idsSize);
|
||||
if (sorted == null) {
|
||||
blockFactory.adjustBreaker(-selectedCountsSize - idsSize);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private Sorted sortedForOrdinalMerging(GroupingState other) {
|
||||
if (sortedForOrdinalMerging == null) {
|
||||
try (var selected = IntVector.range(0, maxGroupId + 1, blockFactory)) {
|
||||
sortedForOrdinalMerging = buildSorted(selected);
|
||||
$if(BytesRef)$
|
||||
// hash all the bytes to the destination to avoid hashing them multiple times
|
||||
BytesRef scratch = new BytesRef();
|
||||
final int totalValue = Math.toIntExact(bytes.size());
|
||||
blockFactory.adjustBreaker((long) totalValue * Integer.BYTES);
|
||||
try {
|
||||
final int[] mappedIds = new int[totalValue];
|
||||
for (int i = 0; i < totalValue; i++) {
|
||||
var v = bytes.get(i, scratch);
|
||||
mappedIds[i] = Math.toIntExact(BlockHash.hashOrdToGroup(other.bytes.add(v)));
|
||||
}
|
||||
// no longer need the bytes
|
||||
bytes.close();
|
||||
bytes = null;
|
||||
for (int i = 0; i < sortedForOrdinalMerging.ids.length; i++) {
|
||||
sortedForOrdinalMerging.ids[i] = mappedIds[Math.toIntExact(values.getKey2(sortedForOrdinalMerging.ids[i]))];
|
||||
}
|
||||
} finally {
|
||||
blockFactory.adjustBreaker(-(long) totalValue * Integer.BYTES);
|
||||
}
|
||||
$endif$
|
||||
}
|
||||
}
|
||||
return sortedForOrdinalMerging;
|
||||
}
|
||||
|
||||
Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) {
|
||||
/*
|
||||
* Insert the ids in order.
|
||||
|
@ -418,11 +481,11 @@ $endif$
|
|||
int count = end - start;
|
||||
switch (count) {
|
||||
case 0 -> builder.appendNull();
|
||||
case 1 -> append(builder, ids[start]$if(BytesRef)$, scratch$endif$);
|
||||
case 1 -> builder.append$Type$(getValue(ids[start]$if(BytesRef)$, scratch$endif$));
|
||||
default -> {
|
||||
builder.beginPositionEntry();
|
||||
for (int i = start; i < end; i++) {
|
||||
append(builder, ids[i]$if(BytesRef)$, scratch$endif$);
|
||||
builder.append$Type$(getValue(ids[i]$if(BytesRef)$, scratch$endif$));
|
||||
}
|
||||
builder.endPositionEntry();
|
||||
}
|
||||
|
@ -470,29 +533,24 @@ $if(BytesRef)$
|
|||
}
|
||||
}
|
||||
}
|
||||
$endif$
|
||||
|
||||
private void append($Type$Block.Builder builder, int id, BytesRef scratch) {
|
||||
BytesRef value = bytes.get(values.getKey2(id), scratch);
|
||||
builder.appendBytesRef(value);
|
||||
}
|
||||
|
||||
$else$
|
||||
private void append($Type$Block.Builder builder, int id) {
|
||||
$if(long)$
|
||||
long value = values.getKey2(id);
|
||||
$type$ getValue(int valueId$if(BytesRef)$, BytesRef scratch$endif$) {
|
||||
$if(BytesRef)$
|
||||
return bytes.get(values.getKey2(valueId), scratch);
|
||||
$elseif(long)$
|
||||
return values.getKey2(valueId);
|
||||
$elseif(double)$
|
||||
double value = Double.longBitsToDouble(values.getKey2(id));
|
||||
return Double.longBitsToDouble(values.getKey2(valueId));
|
||||
$elseif(float)$
|
||||
long both = values.get(id);
|
||||
float value = Float.intBitsToFloat((int) both);
|
||||
long both = values.get(valueId);
|
||||
return Float.intBitsToFloat((int) both);
|
||||
$elseif(int)$
|
||||
long both = values.get(id);
|
||||
int value = (int) both;
|
||||
long both = values.get(valueId);
|
||||
return (int) both;
|
||||
$endif$
|
||||
builder.append$Type$(value);
|
||||
}
|
||||
|
||||
$endif$
|
||||
@Override
|
||||
public void enableGroupIdTracking(SeenGroupIds seen) {
|
||||
// we figure out seen values from nulls on the values block
|
||||
|
@ -501,9 +559,9 @@ $endif$
|
|||
@Override
|
||||
public void close() {
|
||||
$if(BytesRef)$
|
||||
Releasables.closeExpectNoException(values, bytes);
|
||||
Releasables.closeExpectNoException(values, bytes, sortedForOrdinalMerging);
|
||||
$else$
|
||||
values.close();
|
||||
Releasables.closeExpectNoException(values, sortedForOrdinalMerging);
|
||||
$endif$
|
||||
}
|
||||
}
|
||||
|
|
|
@ -85,7 +85,7 @@ public class HashAggregationOperator implements Operator {
|
|||
|
||||
private final BlockHash blockHash;
|
||||
|
||||
private final List<GroupingAggregator> aggregators;
|
||||
protected final List<GroupingAggregator> aggregators;
|
||||
|
||||
protected final DriverContext driverContext;
|
||||
|
||||
|
|
Loading…
Reference in New Issue