Allow incubating Panama Vector in simdvec, and add vectorized ipByteBin (#112933)
Add support for vectorized ipByteBin. The structure of the implementation and loading framework mirror that of Lucene, but is simplified by avoiding reflective loading since ES has support for a MRJar section for 21. For now, we just disable warnings-as-errors in this small sourceset, since -Xlint:-incubating is only support since JDK 22. The number of source files is small here. Will investigate how to assert that just the single incubating warning is emitted by javac, at a later point.
This commit is contained in:
parent
e1bba9b390
commit
7decd52132
|
@ -0,0 +1,5 @@
|
|||
pr: 112933
|
||||
summary: "Allow incubating Panama Vector in simdvec, and add vectorized `ipByteBin`"
|
||||
area: Search
|
||||
type: enhancement
|
||||
issues: []
|
|
@ -23,6 +23,19 @@ dependencies {
|
|||
}
|
||||
}
|
||||
|
||||
tasks.named("compileMain21Java").configure {
|
||||
options.compilerArgs << '--add-modules=jdk.incubator.vector'
|
||||
// we remove Werror, since incubating suppression (-Xlint:-incubating)
|
||||
// is only support since JDK 22
|
||||
options.compilerArgs -= '-Werror'
|
||||
}
|
||||
|
||||
test {
|
||||
if (JavaVersion.current().majorVersion.toInteger() >= 21) {
|
||||
jvmArgs '--add-modules=jdk.incubator.vector'
|
||||
}
|
||||
}
|
||||
|
||||
tasks.withType(CheckForbiddenApisTask).configureEach {
|
||||
replaceSignatureFiles 'jdk-signatures'
|
||||
}
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
module org.elasticsearch.simdvec {
|
||||
requires org.elasticsearch.nativeaccess;
|
||||
requires org.apache.lucene.core;
|
||||
requires org.elasticsearch.logging;
|
||||
|
||||
exports org.elasticsearch.simdvec to org.elasticsearch.server;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the "Elastic License
|
||||
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||
* Public License v 1"; you may not use this file except in compliance with, at
|
||||
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||
*/
|
||||
|
||||
package org.elasticsearch.simdvec;
|
||||
|
||||
import org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport;
|
||||
import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;
|
||||
|
||||
import static org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport.B_QUERY;
|
||||
|
||||
public class ESVectorUtil {
|
||||
|
||||
private static final ESVectorUtilSupport IMPL = ESVectorizationProvider.getInstance().getVectorUtilSupport();
|
||||
|
||||
public static long ipByteBinByte(byte[] q, byte[] d) {
|
||||
if (q.length != d.length * B_QUERY) {
|
||||
throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + B_QUERY + " x " + d.length);
|
||||
}
|
||||
return IMPL.ipByteBinByte(q, d);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,39 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the "Elastic License
|
||||
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||
* Public License v 1"; you may not use this file except in compliance with, at
|
||||
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||
*/
|
||||
|
||||
package org.elasticsearch.simdvec.internal.vectorization;
|
||||
|
||||
import org.apache.lucene.util.BitUtil;
|
||||
|
||||
final class DefaultESVectorUtilSupport implements ESVectorUtilSupport {
|
||||
|
||||
DefaultESVectorUtilSupport() {}
|
||||
|
||||
@Override
|
||||
public long ipByteBinByte(byte[] q, byte[] d) {
|
||||
return ipByteBinByteImpl(q, d);
|
||||
}
|
||||
|
||||
public static long ipByteBinByteImpl(byte[] q, byte[] d) {
|
||||
long ret = 0;
|
||||
int size = d.length;
|
||||
for (int i = 0; i < B_QUERY; i++) {
|
||||
int r = 0;
|
||||
long subRet = 0;
|
||||
for (final int upperBound = d.length & -Integer.BYTES; r < upperBound; r += Integer.BYTES) {
|
||||
subRet += Integer.bitCount((int) BitUtil.VH_NATIVE_INT.get(q, i * size + r) & (int) BitUtil.VH_NATIVE_INT.get(d, r));
|
||||
}
|
||||
for (; r < d.length; r++) {
|
||||
subRet += Integer.bitCount((q[i * size + r] & d[r]) & 0xFF);
|
||||
}
|
||||
ret += subRet << i;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the "Elastic License
|
||||
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||
* Public License v 1"; you may not use this file except in compliance with, at
|
||||
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||
*/
|
||||
|
||||
package org.elasticsearch.simdvec.internal.vectorization;
|
||||
|
||||
final class DefaultESVectorizationProvider extends ESVectorizationProvider {
|
||||
private final ESVectorUtilSupport vectorUtilSupport;
|
||||
|
||||
DefaultESVectorizationProvider() {
|
||||
vectorUtilSupport = new DefaultESVectorUtilSupport();
|
||||
}
|
||||
|
||||
@Override
|
||||
public ESVectorUtilSupport getVectorUtilSupport() {
|
||||
return vectorUtilSupport;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,17 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the "Elastic License
|
||||
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||
* Public License v 1"; you may not use this file except in compliance with, at
|
||||
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||
*/
|
||||
|
||||
package org.elasticsearch.simdvec.internal.vectorization;
|
||||
|
||||
public interface ESVectorUtilSupport {
|
||||
|
||||
short B_QUERY = 4;
|
||||
|
||||
long ipByteBinByte(byte[] q, byte[] d);
|
||||
}
|
|
@ -0,0 +1,38 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the "Elastic License
|
||||
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||
* Public License v 1"; you may not use this file except in compliance with, at
|
||||
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||
*/
|
||||
|
||||
package org.elasticsearch.simdvec.internal.vectorization;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
public abstract class ESVectorizationProvider {
|
||||
|
||||
public static ESVectorizationProvider getInstance() {
|
||||
return Objects.requireNonNull(
|
||||
ESVectorizationProvider.Holder.INSTANCE,
|
||||
"call to getInstance() from subclass of VectorizationProvider"
|
||||
);
|
||||
}
|
||||
|
||||
ESVectorizationProvider() {}
|
||||
|
||||
public abstract ESVectorUtilSupport getVectorUtilSupport();
|
||||
|
||||
// visible for tests
|
||||
static ESVectorizationProvider lookup(boolean testMode) {
|
||||
return new DefaultESVectorizationProvider();
|
||||
}
|
||||
|
||||
/** This static holder class prevents classloading deadlock. */
|
||||
private static final class Holder {
|
||||
private Holder() {}
|
||||
|
||||
static final ESVectorizationProvider INSTANCE = lookup(false);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,87 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the "Elastic License
|
||||
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||
* Public License v 1"; you may not use this file except in compliance with, at
|
||||
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||
*/
|
||||
|
||||
package org.elasticsearch.simdvec.internal.vectorization;
|
||||
|
||||
import org.apache.lucene.util.Constants;
|
||||
import org.elasticsearch.logging.LogManager;
|
||||
import org.elasticsearch.logging.Logger;
|
||||
|
||||
import java.util.Locale;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
|
||||
public abstract class ESVectorizationProvider {
|
||||
|
||||
protected static final Logger logger = LogManager.getLogger(ESVectorizationProvider.class);
|
||||
|
||||
public static ESVectorizationProvider getInstance() {
|
||||
return Objects.requireNonNull(
|
||||
ESVectorizationProvider.Holder.INSTANCE,
|
||||
"call to getInstance() from subclass of VectorizationProvider"
|
||||
);
|
||||
}
|
||||
|
||||
ESVectorizationProvider() {}
|
||||
|
||||
public abstract ESVectorUtilSupport getVectorUtilSupport();
|
||||
|
||||
// visible for tests
|
||||
static ESVectorizationProvider lookup(boolean testMode) {
|
||||
final int runtimeVersion = Runtime.version().feature();
|
||||
assert runtimeVersion >= 21;
|
||||
if (runtimeVersion <= 23) {
|
||||
// only use vector module with Hotspot VM
|
||||
if (Constants.IS_HOTSPOT_VM == false) {
|
||||
logger.warn("Java runtime is not using Hotspot VM; Java vector incubator API can't be enabled.");
|
||||
return new DefaultESVectorizationProvider();
|
||||
}
|
||||
// is the incubator module present and readable (JVM providers may to exclude them or it is
|
||||
// build with jlink)
|
||||
final var vectorMod = lookupVectorModule();
|
||||
if (vectorMod.isEmpty()) {
|
||||
logger.warn(
|
||||
"Java vector incubator module is not readable. "
|
||||
+ "For optimal vector performance, pass '--add-modules jdk.incubator.vector' to enable Vector API."
|
||||
);
|
||||
return new DefaultESVectorizationProvider();
|
||||
}
|
||||
vectorMod.ifPresent(ESVectorizationProvider.class.getModule()::addReads);
|
||||
var impl = new PanamaESVectorizationProvider();
|
||||
logger.info(
|
||||
String.format(
|
||||
Locale.ENGLISH,
|
||||
"Java vector incubator API enabled; uses preferredBitSize=%d",
|
||||
PanamaESVectorUtilSupport.VECTOR_BITSIZE
|
||||
)
|
||||
);
|
||||
return impl;
|
||||
} else {
|
||||
logger.warn(
|
||||
"You are running with unsupported Java "
|
||||
+ runtimeVersion
|
||||
+ ". To make full use of the Vector API, please update Elasticsearch."
|
||||
);
|
||||
}
|
||||
return new DefaultESVectorizationProvider();
|
||||
}
|
||||
|
||||
private static Optional<Module> lookupVectorModule() {
|
||||
return Optional.ofNullable(ESVectorizationProvider.class.getModule().getLayer())
|
||||
.orElse(ModuleLayer.boot())
|
||||
.findModule("jdk.incubator.vector");
|
||||
}
|
||||
|
||||
/** This static holder class prevents classloading deadlock. */
|
||||
private static final class Holder {
|
||||
private Holder() {}
|
||||
|
||||
static final ESVectorizationProvider INSTANCE = lookup(false);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,153 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the "Elastic License
|
||||
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||
* Public License v 1"; you may not use this file except in compliance with, at
|
||||
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||
*/
|
||||
|
||||
package org.elasticsearch.simdvec.internal.vectorization;
|
||||
|
||||
import jdk.incubator.vector.ByteVector;
|
||||
import jdk.incubator.vector.IntVector;
|
||||
import jdk.incubator.vector.LongVector;
|
||||
import jdk.incubator.vector.VectorOperators;
|
||||
import jdk.incubator.vector.VectorShape;
|
||||
import jdk.incubator.vector.VectorSpecies;
|
||||
|
||||
import org.apache.lucene.util.Constants;
|
||||
|
||||
public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
|
||||
|
||||
static final int VECTOR_BITSIZE;
|
||||
|
||||
/** Whether integer vectors can be trusted to actually be fast. */
|
||||
static final boolean HAS_FAST_INTEGER_VECTORS;
|
||||
|
||||
static {
|
||||
// default to platform supported bitsize
|
||||
VECTOR_BITSIZE = VectorShape.preferredShape().vectorBitSize();
|
||||
|
||||
// hotspot misses some SSE intrinsics, workaround it
|
||||
// to be fair, they do document this thing only works well with AVX2/AVX3 and Neon
|
||||
boolean isAMD64withoutAVX2 = Constants.OS_ARCH.equals("amd64") && VECTOR_BITSIZE < 256;
|
||||
HAS_FAST_INTEGER_VECTORS = isAMD64withoutAVX2 == false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long ipByteBinByte(byte[] q, byte[] d) {
|
||||
// 128 / 8 == 16
|
||||
if (d.length >= 16 && HAS_FAST_INTEGER_VECTORS) {
|
||||
if (VECTOR_BITSIZE >= 256) {
|
||||
return ipByteBin256(q, d);
|
||||
} else if (VECTOR_BITSIZE == 128) {
|
||||
return ipByteBin128(q, d);
|
||||
}
|
||||
}
|
||||
return DefaultESVectorUtilSupport.ipByteBinByteImpl(q, d);
|
||||
}
|
||||
|
||||
private static final VectorSpecies<Byte> BYTE_SPECIES_128 = ByteVector.SPECIES_128;
|
||||
private static final VectorSpecies<Byte> BYTE_SPECIES_256 = ByteVector.SPECIES_256;
|
||||
|
||||
static long ipByteBin256(byte[] q, byte[] d) {
|
||||
long subRet0 = 0;
|
||||
long subRet1 = 0;
|
||||
long subRet2 = 0;
|
||||
long subRet3 = 0;
|
||||
int i = 0;
|
||||
|
||||
if (d.length >= ByteVector.SPECIES_256.vectorByteSize() * 2) {
|
||||
int limit = ByteVector.SPECIES_256.loopBound(d.length);
|
||||
var sum0 = LongVector.zero(LongVector.SPECIES_256);
|
||||
var sum1 = LongVector.zero(LongVector.SPECIES_256);
|
||||
var sum2 = LongVector.zero(LongVector.SPECIES_256);
|
||||
var sum3 = LongVector.zero(LongVector.SPECIES_256);
|
||||
for (; i < limit; i += ByteVector.SPECIES_256.length()) {
|
||||
var vq0 = ByteVector.fromArray(BYTE_SPECIES_256, q, i).reinterpretAsLongs();
|
||||
var vq1 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + d.length).reinterpretAsLongs();
|
||||
var vq2 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + d.length * 2).reinterpretAsLongs();
|
||||
var vq3 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + d.length * 3).reinterpretAsLongs();
|
||||
var vd = ByteVector.fromArray(BYTE_SPECIES_256, d, i).reinterpretAsLongs();
|
||||
sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT));
|
||||
sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT));
|
||||
sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT));
|
||||
sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT));
|
||||
}
|
||||
subRet0 += sum0.reduceLanes(VectorOperators.ADD);
|
||||
subRet1 += sum1.reduceLanes(VectorOperators.ADD);
|
||||
subRet2 += sum2.reduceLanes(VectorOperators.ADD);
|
||||
subRet3 += sum3.reduceLanes(VectorOperators.ADD);
|
||||
}
|
||||
|
||||
if (d.length - i >= ByteVector.SPECIES_128.vectorByteSize()) {
|
||||
var sum0 = LongVector.zero(LongVector.SPECIES_128);
|
||||
var sum1 = LongVector.zero(LongVector.SPECIES_128);
|
||||
var sum2 = LongVector.zero(LongVector.SPECIES_128);
|
||||
var sum3 = LongVector.zero(LongVector.SPECIES_128);
|
||||
int limit = ByteVector.SPECIES_128.loopBound(d.length);
|
||||
for (; i < limit; i += ByteVector.SPECIES_128.length()) {
|
||||
var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsLongs();
|
||||
var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length).reinterpretAsLongs();
|
||||
var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length * 2).reinterpretAsLongs();
|
||||
var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length * 3).reinterpretAsLongs();
|
||||
var vd = ByteVector.fromArray(BYTE_SPECIES_128, d, i).reinterpretAsLongs();
|
||||
sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT));
|
||||
sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT));
|
||||
sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT));
|
||||
sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT));
|
||||
}
|
||||
subRet0 += sum0.reduceLanes(VectorOperators.ADD);
|
||||
subRet1 += sum1.reduceLanes(VectorOperators.ADD);
|
||||
subRet2 += sum2.reduceLanes(VectorOperators.ADD);
|
||||
subRet3 += sum3.reduceLanes(VectorOperators.ADD);
|
||||
}
|
||||
// tail as bytes
|
||||
for (; i < d.length; i++) {
|
||||
subRet0 += Integer.bitCount((q[i] & d[i]) & 0xFF);
|
||||
subRet1 += Integer.bitCount((q[i + d.length] & d[i]) & 0xFF);
|
||||
subRet2 += Integer.bitCount((q[i + 2 * d.length] & d[i]) & 0xFF);
|
||||
subRet3 += Integer.bitCount((q[i + 3 * d.length] & d[i]) & 0xFF);
|
||||
}
|
||||
return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
|
||||
}
|
||||
|
||||
public static long ipByteBin128(byte[] q, byte[] d) {
|
||||
long subRet0 = 0;
|
||||
long subRet1 = 0;
|
||||
long subRet2 = 0;
|
||||
long subRet3 = 0;
|
||||
int i = 0;
|
||||
|
||||
var sum0 = IntVector.zero(IntVector.SPECIES_128);
|
||||
var sum1 = IntVector.zero(IntVector.SPECIES_128);
|
||||
var sum2 = IntVector.zero(IntVector.SPECIES_128);
|
||||
var sum3 = IntVector.zero(IntVector.SPECIES_128);
|
||||
int limit = ByteVector.SPECIES_128.loopBound(d.length);
|
||||
for (; i < limit; i += ByteVector.SPECIES_128.length()) {
|
||||
var vd = ByteVector.fromArray(BYTE_SPECIES_128, d, i).reinterpretAsInts();
|
||||
var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsInts();
|
||||
var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length).reinterpretAsInts();
|
||||
var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length * 2).reinterpretAsInts();
|
||||
var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length * 3).reinterpretAsInts();
|
||||
sum0 = sum0.add(vd.and(vq0).lanewise(VectorOperators.BIT_COUNT));
|
||||
sum1 = sum1.add(vd.and(vq1).lanewise(VectorOperators.BIT_COUNT));
|
||||
sum2 = sum2.add(vd.and(vq2).lanewise(VectorOperators.BIT_COUNT));
|
||||
sum3 = sum3.add(vd.and(vq3).lanewise(VectorOperators.BIT_COUNT));
|
||||
}
|
||||
subRet0 += sum0.reduceLanes(VectorOperators.ADD);
|
||||
subRet1 += sum1.reduceLanes(VectorOperators.ADD);
|
||||
subRet2 += sum2.reduceLanes(VectorOperators.ADD);
|
||||
subRet3 += sum3.reduceLanes(VectorOperators.ADD);
|
||||
// tail as bytes
|
||||
for (; i < d.length; i++) {
|
||||
int dValue = d[i];
|
||||
subRet0 += Integer.bitCount((dValue & q[i]) & 0xFF);
|
||||
subRet1 += Integer.bitCount((dValue & q[i + d.length]) & 0xFF);
|
||||
subRet2 += Integer.bitCount((dValue & q[i + 2 * d.length]) & 0xFF);
|
||||
subRet3 += Integer.bitCount((dValue & q[i + 3 * d.length]) & 0xFF);
|
||||
}
|
||||
return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,24 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the "Elastic License
|
||||
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||
* Public License v 1"; you may not use this file except in compliance with, at
|
||||
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||
*/
|
||||
|
||||
package org.elasticsearch.simdvec.internal.vectorization;
|
||||
|
||||
final class PanamaESVectorizationProvider extends ESVectorizationProvider {
|
||||
|
||||
private final ESVectorUtilSupport vectorUtilSupport;
|
||||
|
||||
PanamaESVectorizationProvider() {
|
||||
vectorUtilSupport = new PanamaESVectorUtilSupport();
|
||||
}
|
||||
|
||||
@Override
|
||||
public ESVectorUtilSupport getVectorUtilSupport() {
|
||||
return vectorUtilSupport;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,130 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the "Elastic License
|
||||
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||
* Public License v 1"; you may not use this file except in compliance with, at
|
||||
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||
*/
|
||||
|
||||
package org.elasticsearch.simdvec;
|
||||
|
||||
import org.elasticsearch.simdvec.internal.vectorization.BaseVectorizationTests;
|
||||
import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
||||
import static org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport.B_QUERY;
|
||||
|
||||
public class ESVectorUtilTests extends BaseVectorizationTests {
|
||||
|
||||
static final ESVectorizationProvider defaultedProvider = BaseVectorizationTests.defaultProvider();
|
||||
static final ESVectorizationProvider defOrPanamaProvider = BaseVectorizationTests.maybePanamaProvider();
|
||||
|
||||
public void testIpByteBinInvariants() {
|
||||
int iterations = atLeast(10);
|
||||
for (int i = 0; i < iterations; i++) {
|
||||
int size = randomIntBetween(1, 10);
|
||||
var d = new byte[size];
|
||||
var q = new byte[size * B_QUERY - 1];
|
||||
expectThrows(IllegalArgumentException.class, () -> ESVectorUtil.ipByteBinByte(q, d));
|
||||
}
|
||||
}
|
||||
|
||||
public void testBasicIpByteBin() {
|
||||
testBasicIpByteBinImpl(ESVectorUtil::ipByteBinByte);
|
||||
testBasicIpByteBinImpl(defaultedProvider.getVectorUtilSupport()::ipByteBinByte);
|
||||
testBasicIpByteBinImpl(defOrPanamaProvider.getVectorUtilSupport()::ipByteBinByte);
|
||||
}
|
||||
|
||||
interface IpByteBin {
|
||||
long apply(byte[] q, byte[] d);
|
||||
}
|
||||
|
||||
void testBasicIpByteBinImpl(IpByteBin ipByteBinFunc) {
|
||||
assertEquals(15L, ipByteBinFunc.apply(new byte[] { 1, 1, 1, 1 }, new byte[] { 1 }));
|
||||
assertEquals(30L, ipByteBinFunc.apply(new byte[] { 1, 2, 1, 2, 1, 2, 1, 2 }, new byte[] { 1, 2 }));
|
||||
|
||||
var d = new byte[] { 1, 2, 3 };
|
||||
var q = new byte[] { 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3 };
|
||||
assert scalarIpByteBin(q, d) == 60L; // 4 + 8 + 16 + 32
|
||||
assertEquals(60L, ipByteBinFunc.apply(q, d));
|
||||
|
||||
d = new byte[] { 1, 2, 3, 4 };
|
||||
q = new byte[] { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 };
|
||||
assert scalarIpByteBin(q, d) == 75L; // 5 + 10 + 20 + 40
|
||||
assertEquals(75L, ipByteBinFunc.apply(q, d));
|
||||
|
||||
d = new byte[] { 1, 2, 3, 4, 5 };
|
||||
q = new byte[] { 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5 };
|
||||
assert scalarIpByteBin(q, d) == 105L; // 7 + 14 + 28 + 56
|
||||
assertEquals(105L, ipByteBinFunc.apply(q, d));
|
||||
|
||||
d = new byte[] { 1, 2, 3, 4, 5, 6 };
|
||||
q = new byte[] { 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6 };
|
||||
assert scalarIpByteBin(q, d) == 135L; // 9 + 18 + 36 + 72
|
||||
assertEquals(135L, ipByteBinFunc.apply(q, d));
|
||||
|
||||
d = new byte[] { 1, 2, 3, 4, 5, 6, 7 };
|
||||
q = new byte[] { 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7 };
|
||||
assert scalarIpByteBin(q, d) == 180L; // 12 + 24 + 48 + 96
|
||||
assertEquals(180L, ipByteBinFunc.apply(q, d));
|
||||
|
||||
d = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8 };
|
||||
q = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8 };
|
||||
assert scalarIpByteBin(q, d) == 195L; // 13 + 26 + 52 + 104
|
||||
assertEquals(195L, ipByteBinFunc.apply(q, d));
|
||||
|
||||
d = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9 };
|
||||
q = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9 };
|
||||
assert scalarIpByteBin(q, d) == 225L; // 15 + 30 + 60 + 120
|
||||
assertEquals(225L, ipByteBinFunc.apply(q, d));
|
||||
}
|
||||
|
||||
public void testIpByteBin() {
|
||||
testIpByteBinImpl(ESVectorUtil::ipByteBinByte);
|
||||
testIpByteBinImpl(defaultedProvider.getVectorUtilSupport()::ipByteBinByte);
|
||||
testIpByteBinImpl(defOrPanamaProvider.getVectorUtilSupport()::ipByteBinByte);
|
||||
}
|
||||
|
||||
void testIpByteBinImpl(IpByteBin ipByteBinFunc) {
|
||||
int iterations = atLeast(50);
|
||||
for (int i = 0; i < iterations; i++) {
|
||||
int size = random().nextInt(5000);
|
||||
var d = new byte[size];
|
||||
var q = new byte[size * B_QUERY];
|
||||
random().nextBytes(d);
|
||||
random().nextBytes(q);
|
||||
assertEquals(scalarIpByteBin(q, d), ipByteBinFunc.apply(q, d));
|
||||
|
||||
Arrays.fill(d, Byte.MAX_VALUE);
|
||||
Arrays.fill(q, Byte.MAX_VALUE);
|
||||
assertEquals(scalarIpByteBin(q, d), ipByteBinFunc.apply(q, d));
|
||||
|
||||
Arrays.fill(d, Byte.MIN_VALUE);
|
||||
Arrays.fill(q, Byte.MIN_VALUE);
|
||||
assertEquals(scalarIpByteBin(q, d), ipByteBinFunc.apply(q, d));
|
||||
}
|
||||
}
|
||||
|
||||
static int scalarIpByteBin(byte[] q, byte[] d) {
|
||||
int res = 0;
|
||||
for (int i = 0; i < B_QUERY; i++) {
|
||||
res += (popcount(q, i * d.length, d, d.length) << i);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
public static int popcount(byte[] a, int aOffset, byte[] b, int length) {
|
||||
int res = 0;
|
||||
for (int j = 0; j < length; j++) {
|
||||
int value = (a[aOffset + j] & b[j]) & 0xFF;
|
||||
for (int k = 0; k < Byte.SIZE; k++) {
|
||||
if ((value & (1 << k)) != 0) {
|
||||
++res;
|
||||
}
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the "Elastic License
|
||||
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||
* Public License v 1"; you may not use this file except in compliance with, at
|
||||
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||
*/
|
||||
|
||||
package org.elasticsearch.simdvec.internal.vectorization;
|
||||
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.junit.Before;
|
||||
|
||||
public class BaseVectorizationTests extends ESTestCase {
|
||||
|
||||
@Before
|
||||
public void sanity() {
|
||||
assert Runtime.version().feature() < 21 || ModuleLayer.boot().findModule("jdk.incubator.vector").isPresent();
|
||||
}
|
||||
|
||||
public static ESVectorizationProvider defaultProvider() {
|
||||
return new DefaultESVectorizationProvider();
|
||||
}
|
||||
|
||||
public static ESVectorizationProvider maybePanamaProvider() {
|
||||
return ESVectorizationProvider.lookup(true);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue