Half-precision floating-point in Java

后端 未结 5 747
庸人自扰
庸人自扰 2020-12-07 17:26

Is there a Java library anywhere that can perform computations on IEEE 754 half-precision numbers or convert them to and from double-precision?

Either of these approaches would be suitable…

5条回答
  •  长情又很酷
    2020-12-07 17:55

    I created a Java class called HalfPrecisionFloat, which uses x4u's solution. The class has convenience methods and error checking. It goes further and has methods for returning a double and a float from the 2-byte half-precision value.

    Hopefully this will help someone.

    ==>

    import java.nio.ByteBuffer;
    
    /**
     * Accepts various forms of a half-precision (2 byte) floating point number
     * and contains methods to convert it to a full-precision {@code float} or
     * {@code double}.
     *
     * <p>This implementation was inspired by x4u, a user contributing to
     * stackoverflow.com (https://stackoverflow.com/users/237321/x4u).
     *
     * @author dougestep
     */
    public class HalfPrecisionFloat {

      /** The raw 16 bits of the half-precision value. */
      private final short halfPrecision;

      /** Full-precision value; {@code null} until it is first computed or supplied. */
      private Float fullPrecision;

      /**
       * Creates an instance of this class from the supplied byte array. The
       * byte array must be exactly two bytes in length; the bytes are read in
       * the default {@link ByteBuffer} order (big-endian).
       *
       * @param bytes the two-byte byte array.
       * @throws IllegalArgumentException if the array is not two bytes long.
       */
      public HalfPrecisionFloat(byte[] bytes) {
        if (bytes.length != 2) {
          throw new IllegalArgumentException("The supplied byte array "
              + "must be exactly two bytes in length");
        }
        final ByteBuffer buffer = ByteBuffer.wrap(bytes);
        this.halfPrecision = buffer.getShort();
      }

      /**
       * Creates an instance of this class from the supplied short number,
       * whose 16 bits are taken as the half-precision bit pattern.
       *
       * @param number the half-precision bits defined as a short.
       */
      public HalfPrecisionFloat(final short number) {
        this.halfPrecision = number;
        this.fullPrecision = toFullPrecision();
      }

      /**
       * Creates an instance of this class from the supplied full-precision
       * floating point number.
       *
       * <p>Note: the supplied number itself is cached and later returned by
       * {@link #getFullFloat()} and {@link #getFullDouble()}; it is not
       * rounded to the nearest half-precision value first.
       *
       * @param number the float number.
       * @throws IllegalArgumentException if the number is greater than
       *     {@link Short#MAX_VALUE} or less than {@link Short#MIN_VALUE}.
       */
      public HalfPrecisionFloat(final float number) {
        if (number > Short.MAX_VALUE) {
          throw new IllegalArgumentException("The supplied float is too "
              + "large for a two byte representation");
        }
        if (number < Short.MIN_VALUE) {
          throw new IllegalArgumentException("The supplied float is too "
              + "small for a two byte representation");
        }

        final int val = fromFullPrecision(number);
        this.halfPrecision = (short) val;
        this.fullPrecision = number;
      }

      /**
       * Returns the half-precision float as a number defined as a short.
       *
       * @return the half-precision bits as a short.
       */
      public short getHalfPrecisionAsShort() {
        return halfPrecision;
      }

      /**
       * Returns a full-precision floating point number from the
       * half-precision value assigned on this instance. The value is computed
       * on first use and cached afterwards.
       *
       * @return the full-precision floating point number.
       */
      public float getFullFloat() {
        if (fullPrecision == null) {
          fullPrecision = toFullPrecision();
        }
        return fullPrecision;
      }

      /**
       * Returns a full-precision double floating point number from the
       * half-precision value assigned on this instance.
       *
       * @return the value of {@link #getFullFloat()} widened to a double.
       */
      public double getFullDouble() {
        // Plain widening conversion; no need for the deprecated Double constructor.
        return getFullFloat();
      }

      /**
       * Returns the full-precision float number from the half-precision
       * value assigned on this instance. The conversion is exact: every
       * half-precision value maps to the float with the same numeric value.
       *
       * @return the full-precision floating point number.
       */
      private float toFullPrecision() {
        int mantissa = halfPrecision & 0x03ff; // 10 mantissa bits
        int exponent = halfPrecision & 0x7c00; // 5 exponent bits, still in half position

        if (exponent == 0x7c00) {
          // All exponent bits set (NaN/Inf): set all 8 float exponent bits.
          exponent = 0x3fc00;
        } else if (exponent != 0) {
          // Normal value: re-bias the exponent, (127 - 15) << 10.
          exponent += 0x1c000;
        } else if (mantissa != 0) {
          // Subnormal half: shift the mantissa up until its implicit bit
          // (0x400) appears, lowering the exponent once per shift.
          exponent = 0x1c400;
          do {
            mantissa <<= 1;
            exponent -= 0x400;
          } while ((mantissa & 0x400) == 0);
          mantissa &= 0x3ff; // drop the now-implicit leading bit
        }
        // else: exponent == 0 and mantissa == 0, i.e. +/-0 stays +/-0.

        // Sign moves from bit 15 to bit 31; exponent and mantissa move up by 23 - 10 bits.
        return Float.intBitsToFloat(
            (halfPrecision & 0x8000) << 16 | (exponent | mantissa) << 13);
      }

      /**
       * Returns the half-precision bit pattern (in the low 16 bits of the
       * returned int) for the supplied full-precision floating point number.
       *
       * @param number the full-precision floating point number.
       * @return the integer holding the half-precision bits.
       */
      private int fromFullPrecision(final float number) {
        final int fbits = Float.floatToIntBits(number);
        final int sign = (fbits >>> 16) & 0x8000;  // sign bit moved to bit 15
        final int magnitude = fbits & 0x7fffffff;  // float bits without the sign
        // 13 low bits are dropped below; adding half of 2^13 rounds the result.
        int val = magnitude + 0x1000;

        if (val >= 0x47800000) { // rounded bits reach those of 65536.0f (2^16)
          if (magnitude >= 0x47800000) {
            if (val < 0x7f800000) {
              // Finite but beyond the half exponent range: exponent all ones, mantissa 0.
              return sign | 0x7c00;
            }
            // Float exponent all ones (Inf/NaN): keep the top mantissa bits.
            return sign | 0x7c00 | ((fbits & 0x007fffff) >>> 13);
          }
          // Only the rounding crossed the limit: clamp to the largest exponent
          // below all-ones with a full mantissa.
          return sign | 0x7bff;
        }
        if (val >= 0x38800000) { // at least the bits of 2^-14: normal half
          return sign | ((val - 0x38000000) >>> 13); // re-bias exponent and drop 13 bits
        }
        if (val < 0x33000000) { // below the bits of 2^-25: signed zero
          return sign;
        }

        // Subnormal half: restore the implicit mantissa bit, add a rounding
        // term that depends on how many bits are cut off, then shift into place.
        val = magnitude >>> 23; // float exponent field
        return sign | ((((fbits & 0x7fffff) | 0x800000) + (0x800000 >>> (val - 102)))
            >>> (126 - val));
      }
    }

    And here are the unit tests:

    import org.junit.Assert;
    import org.junit.Test;
    
    import java.nio.ByteBuffer;
    
    /**
     * Unit tests for {@code HalfPrecisionFloat}: round trips between float,
     * short and byte-array forms, plus rejection of invalid input.
     */
    public class TestHalfPrecision {

      /**
       * Encodes the supplied float as half precision and returns the two
       * bytes as they would arrive from an external source.
       *
       * @param fullPrecision the float to encode.
       * @return the two-byte half-precision representation.
       */
      private byte[] simulateBytes(final float fullPrecision) {
        HalfPrecisionFloat halfFloat = new HalfPrecisionFloat(fullPrecision);
        short halfShort = halfFloat.getHalfPrecisionAsShort();

        ByteBuffer buffer = ByteBuffer.allocate(2);
        buffer.putShort(halfShort);
        return buffer.array();
      }

      /** Round-trips 1.2f through every representation the class supports. */
      @Test
      public void testHalfPrecisionToFloatApproach() {
        final float startingValue = 1.2f;
        final float closestValue = 1.2001953f;
        final short shortRepresentation = (short) 15565;

        // float -> bytes -> float yields the closest half-precision value.
        byte[] bytes = simulateBytes(startingValue);
        HalfPrecisionFloat halfFloat = new HalfPrecisionFloat(bytes);
        final float retFloat = halfFloat.getFullFloat();
        Assert.assertEquals(closestValue, retFloat, 0.0f);

        // float -> short gives the expected bit pattern.
        HalfPrecisionFloat otherWay = new HalfPrecisionFloat(retFloat);
        final short shrtValue = otherWay.getHalfPrecisionAsShort();
        Assert.assertEquals(shortRepresentation, shrtValue);

        // short -> float is stable.
        HalfPrecisionFloat backAgain = new HalfPrecisionFloat(shrtValue);
        final float backFlt = backAgain.getFullFloat();
        Assert.assertEquals(closestValue, backFlt, 0.0f);

        // An instance built from a float hands back that same float as a double.
        HalfPrecisionFloat dbl = new HalfPrecisionFloat(startingValue);
        final double retDbl = dbl.getFullDouble();
        Assert.assertEquals((double) startingValue, retDbl, 0.0);
      }

      /** A byte array that is not exactly two bytes long must be rejected. */
      @Test(expected = IllegalArgumentException.class)
      public void testInvalidByteArray() {
        ByteBuffer buffer = ByteBuffer.allocate(4);
        buffer.putFloat(Float.MAX_VALUE);
        byte[] bytes = buffer.array();

        new HalfPrecisionFloat(bytes);
      }

      /** A float above the accepted range must be rejected. */
      @Test(expected = IllegalArgumentException.class)
      public void testInvalidMaxFloat() {
        new HalfPrecisionFloat(Float.MAX_VALUE);
      }

      /** A float below the accepted range must be rejected. */
      @Test(expected = IllegalArgumentException.class)
      public void testInvalidMinFloat() {
        // Explicit float literal so the float constructor is visibly the target.
        new HalfPrecisionFloat(-35000f);
      }

      /** The short constructor stores the supplied bits unchanged. */
      @Test
      public void testCreateWithShort() {
        HalfPrecisionFloat sut = new HalfPrecisionFloat(Short.MAX_VALUE);
        Assert.assertEquals(Short.MAX_VALUE, sut.getHalfPrecisionAsShort());
      }
    }
    

提交回复
热议问题