Speed up bit compared with floats or bytes script operations (#117199)
Instead of using an "if" statement, which doesn't lend itself to vectorization, I switched to expanding the bits and multiplying by the resulting 1s and 0s. This led to a marginal speed improvement on ARM. I expect that the Panama Vector API could be used here to be even faster, but I didn't want to spend any more time on this for the time being.
```
Benchmark                                              (dims)   Mode  Cnt  Score   Error   Units
IpBitVectorScorerBenchmark.dotProductByteIfStatement      768  thrpt    5  2.952 ± 0.026  ops/us
IpBitVectorScorerBenchmark.dotProductByteUnwrap           768  thrpt    5  4.017 ± 0.068  ops/us
IpBitVectorScorerBenchmark.dotProductFloatIfStatement     768  thrpt    5  2.987 ± 0.124  ops/us
IpBitVectorScorerBenchmark.dotProductFloatUnwrap          768  thrpt    5  4.726 ± 0.136  ops/us
```
Benchmark I used: https://gist.github.com/benwtrent/b0edb3975d2f03356c1a5ea84c72abc9
This commit is contained in:
parent
187935eb77
commit
e10fc3c90d
|
@ -0,0 +1,5 @@
|
|||
pr: 117199
|
||||
summary: Speed up bit compared with floats or bytes script operations
|
||||
area: Vector Search
|
||||
type: enhancement
|
||||
issues: []
|
|
@ -61,17 +61,7 @@ public class ESVectorUtil {
|
|||
if (q.length != d.length * Byte.SIZE) {
|
||||
throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + Byte.SIZE + " x " + d.length);
|
||||
}
|
||||
int result = 0;
|
||||
// now combine the two vectors, summing the byte dimensions where the bit in d is `1`
|
||||
for (int i = 0; i < d.length; i++) {
|
||||
byte mask = d[i];
|
||||
for (int j = Byte.SIZE - 1; j >= 0; j--) {
|
||||
if ((mask & (1 << j)) != 0) {
|
||||
result += q[i * Byte.SIZE + Byte.SIZE - 1 - j];
|
||||
}
|
||||
}
|
||||
}
|
||||
return result;
|
||||
return IMPL.ipByteBit(q, d);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -87,16 +77,7 @@ public class ESVectorUtil {
|
|||
if (q.length != d.length * Byte.SIZE) {
|
||||
throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + Byte.SIZE + " x " + d.length);
|
||||
}
|
||||
float result = 0;
|
||||
for (int i = 0; i < d.length; i++) {
|
||||
byte mask = d[i];
|
||||
for (int j = Byte.SIZE - 1; j >= 0; j--) {
|
||||
if ((mask & (1 << j)) != 0) {
|
||||
result += q[i * Byte.SIZE + Byte.SIZE - 1 - j];
|
||||
}
|
||||
}
|
||||
}
|
||||
return result;
|
||||
return IMPL.ipFloatBit(q, d);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -10,9 +10,18 @@
|
|||
package org.elasticsearch.simdvec.internal.vectorization;
|
||||
|
||||
import org.apache.lucene.util.BitUtil;
|
||||
import org.apache.lucene.util.Constants;
|
||||
|
||||
final class DefaultESVectorUtilSupport implements ESVectorUtilSupport {
|
||||
|
||||
private static float fma(float a, float b, float c) {
|
||||
if (Constants.HAS_FAST_SCALAR_FMA) {
|
||||
return Math.fma(a, b, c);
|
||||
} else {
|
||||
return a * b + c;
|
||||
}
|
||||
}
|
||||
|
||||
// Empty constructor with default (package-private) access.
DefaultESVectorUtilSupport() {}
|
||||
|
||||
@Override
|
||||
|
@ -20,6 +29,62 @@ final class DefaultESVectorUtilSupport implements ESVectorUtilSupport {
|
|||
return ipByteBinByteImpl(q, d);
|
||||
}
|
||||
|
||||
/**
 * Inner product of the byte vector {@code q} with the bit vector {@code d}.
 * Simply forwards to the static {@code ipByteBitImpl(q, d)}.
 */
@Override
|
||||
public int ipByteBit(byte[] q, byte[] d) {
|
||||
return ipByteBitImpl(q, d);
|
||||
}
|
||||
|
||||
/**
 * Inner product of the float vector {@code q} with the bit vector {@code d}.
 * Simply forwards to the static {@code ipFloatBitImpl(q, d)}.
 */
@Override
|
||||
public float ipFloatBit(float[] q, byte[] d) {
|
||||
return ipFloatBitImpl(q, d);
|
||||
}
|
||||
|
||||
public static int ipByteBitImpl(byte[] q, byte[] d) {
|
||||
assert q.length == d.length * Byte.SIZE;
|
||||
int acc0 = 0;
|
||||
int acc1 = 0;
|
||||
int acc2 = 0;
|
||||
int acc3 = 0;
|
||||
// now combine the two vectors, summing the byte dimensions where the bit in d is `1`
|
||||
for (int i = 0; i < d.length; i++) {
|
||||
byte mask = d[i];
|
||||
// Make sure its just 1 or 0
|
||||
|
||||
acc0 += q[i * Byte.SIZE + 0] * ((mask >> 7) & 1);
|
||||
acc1 += q[i * Byte.SIZE + 1] * ((mask >> 6) & 1);
|
||||
acc2 += q[i * Byte.SIZE + 2] * ((mask >> 5) & 1);
|
||||
acc3 += q[i * Byte.SIZE + 3] * ((mask >> 4) & 1);
|
||||
|
||||
acc0 += q[i * Byte.SIZE + 4] * ((mask >> 3) & 1);
|
||||
acc1 += q[i * Byte.SIZE + 5] * ((mask >> 2) & 1);
|
||||
acc2 += q[i * Byte.SIZE + 6] * ((mask >> 1) & 1);
|
||||
acc3 += q[i * Byte.SIZE + 7] * ((mask >> 0) & 1);
|
||||
}
|
||||
return acc0 + acc1 + acc2 + acc3;
|
||||
}
|
||||
|
||||
public static float ipFloatBitImpl(float[] q, byte[] d) {
|
||||
assert q.length == d.length * Byte.SIZE;
|
||||
float acc0 = 0;
|
||||
float acc1 = 0;
|
||||
float acc2 = 0;
|
||||
float acc3 = 0;
|
||||
// now combine the two vectors, summing the byte dimensions where the bit in d is `1`
|
||||
for (int i = 0; i < d.length; i++) {
|
||||
byte mask = d[i];
|
||||
acc0 = fma(q[i * Byte.SIZE + 0], (mask >> 7) & 1, acc0);
|
||||
acc1 = fma(q[i * Byte.SIZE + 1], (mask >> 6) & 1, acc1);
|
||||
acc2 = fma(q[i * Byte.SIZE + 2], (mask >> 5) & 1, acc2);
|
||||
acc3 = fma(q[i * Byte.SIZE + 3], (mask >> 4) & 1, acc3);
|
||||
|
||||
acc0 = fma(q[i * Byte.SIZE + 4], (mask >> 3) & 1, acc0);
|
||||
acc1 = fma(q[i * Byte.SIZE + 5], (mask >> 2) & 1, acc1);
|
||||
acc2 = fma(q[i * Byte.SIZE + 6], (mask >> 1) & 1, acc2);
|
||||
acc3 = fma(q[i * Byte.SIZE + 7], (mask >> 0) & 1, acc3);
|
||||
}
|
||||
return acc0 + acc1 + acc2 + acc3;
|
||||
}
|
||||
|
||||
public static long ipByteBinByteImpl(byte[] q, byte[] d) {
|
||||
long ret = 0;
|
||||
int size = d.length;
|
||||
|
|
|
@ -14,4 +14,8 @@ public interface ESVectorUtilSupport {
|
|||
short B_QUERY = 4;
|
||||
|
||||
long ipByteBinByte(byte[] q, byte[] d);
|
||||
|
||||
int ipByteBit(byte[] q, byte[] d);
|
||||
|
||||
float ipFloatBit(float[] q, byte[] d);
|
||||
}
|
||||
|
|
|
@ -48,6 +48,16 @@ public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
|
|||
return DefaultESVectorUtilSupport.ipByteBinByteImpl(q, d);
|
||||
}
|
||||
|
||||
/**
 * Inner product of the byte vector {@code q} with the bit vector {@code d}.
 * Delegates to {@code DefaultESVectorUtilSupport.ipByteBitImpl(q, d)}.
 */
@Override
|
||||
public int ipByteBit(byte[] q, byte[] d) {
|
||||
return DefaultESVectorUtilSupport.ipByteBitImpl(q, d);
|
||||
}
|
||||
|
||||
/**
 * Inner product of the float vector {@code q} with the bit vector {@code d}.
 * Delegates to {@code DefaultESVectorUtilSupport.ipFloatBitImpl(q, d)}.
 */
@Override
|
||||
public float ipFloatBit(float[] q, byte[] d) {
|
||||
return DefaultESVectorUtilSupport.ipFloatBitImpl(q, d);
|
||||
}
|
||||
|
||||
private static final VectorSpecies<Byte> BYTE_SPECIES_128 = ByteVector.SPECIES_128;
|
||||
private static final VectorSpecies<Byte> BYTE_SPECIES_256 = ByteVector.SPECIES_256;
|
||||
|
||||
|
|
Loading…
Reference in New Issue