diff --git a/libraries/extractor/src/main/java/androidx/media3/extractor/ts/MpeghReader.java b/libraries/extractor/src/main/java/androidx/media3/extractor/ts/MpeghReader.java index 3a2fb5ebed..53844a357f 100644 --- a/libraries/extractor/src/main/java/androidx/media3/extractor/ts/MpeghReader.java +++ b/libraries/extractor/src/main/java/androidx/media3/extractor/ts/MpeghReader.java @@ -55,6 +55,8 @@ public final class MpeghReader implements ElementaryStreamReader { private static final int STATE_READING_PACKET_HEADER = 1; private static final int STATE_READING_PACKET_PAYLOAD = 2; + private static final int MAX_MHAS_PACKET_HEADER_SIZE = 15; + private @State int state; private @MonotonicNonNull String formatId; @@ -70,6 +72,7 @@ public final class MpeghReader implements ElementaryStreamReader { private int syncBytes; private final ParsableByteArray headerScratchBytes; + private final ParsableBitArray headerScratchBits; private boolean headerDataFinished; private final ParsableByteArray dataScratchBytes; @@ -77,7 +80,7 @@ public final class MpeghReader implements ElementaryStreamReader { private int payloadBytesRead; private int frameBytes; - @Nullable private MpeghUtil.MhasPacketHeader header; + private MpeghUtil.MhasPacketHeader header; private int samplingRate; private int standardFrameLength; private int truncationSamples; @@ -87,8 +90,10 @@ public final class MpeghReader implements ElementaryStreamReader { /** Constructs a new reader for MPEG-H elementary streams. */ public MpeghReader() { state = STATE_FINDING_SYNC; - headerScratchBytes = new ParsableByteArray(new byte[MpeghUtil.MAX_MHAS_PACKET_HEADER_SIZE]); + headerScratchBytes = new ParsableByteArray(new byte[MAX_MHAS_PACKET_HEADER_SIZE]); + headerScratchBits = new ParsableBitArray(); dataScratchBytes = new ParsableByteArray(); + header = new MpeghUtil.MhasPacketHeader(); samplingRate = C.RATE_UNSET_INT; standardFrameLength = C.LENGTH_UNSET; mainStreamLabel = C.INDEX_UNSET; @@ -102,9 +107,9 @@ public final class MpeghReader implements ElementaryStreamReader { state = STATE_FINDING_SYNC; syncBytes = 0; headerScratchBytes.setPosition(0); + headerScratchBits.setPosition(0); dataScratchBytes.setPosition(0); dataScratchBytes.setLimit(0); - header = null; headerDataFinished = false; payloadBytesRead = 0; frameBytes = 0; @@ -160,7 +165,7 @@ public final class MpeghReader implements ElementaryStreamReader { case STATE_READING_PACKET_HEADER: maybeAdjustHeaderScratchBuffer(); // read into header scratch buffer - if (continueRead(data, headerScratchBytes, MpeghUtil.MAX_MHAS_PACKET_HEADER_SIZE)) { + if (continueRead(data, headerScratchBytes, MAX_MHAS_PACKET_HEADER_SIZE)) { parseHeader(); // write the packet header to output headerScratchBytes.setPosition(0); @@ -177,11 +182,12 @@ public final class MpeghReader implements ElementaryStreamReader { } writeSampleData(data); if (payloadBytesRead == header.packetLength) { - ParsableBitArray bitArray = new ParsableBitArray(dataScratchBytes.getData()); if (header.packetType == MpeghUtil.MhasPacketHeader.PACTYP_MPEGH3DACFG) { - parseConfig(bitArray); + parseConfig(new ParsableBitArray(dataScratchBytes.getData())); } else if (header.packetType == MpeghUtil.MhasPacketHeader.PACTYP_AUDIOTRUNCATION) { - truncationSamples = MpeghUtil.parseAudioTruncationInfo(bitArray); + truncationSamples = + MpeghUtil.parseAudioTruncationInfo( + new ParsableBitArray(dataScratchBytes.getData())); } else if (header.packetType == MpeghUtil.MhasPacketHeader.PACTYP_MPEGH3DAFRAME) { finalizeFrame(); } @@ -253,8 +259,9 @@ public final class MpeghReader implements ElementaryStreamReader { * @throws ParserException if a valid {@link MpeghUtil.Mpegh3daConfig} cannot be parsed. */ private void parseHeader() throws ParserException { + headerScratchBits.reset(headerScratchBytes.getData()); // parse the MHAS packet header - header = MpeghUtil.parseMhasPacketHeader(new ParsableBitArray(headerScratchBytes.getData())); + MpeghUtil.parseMhasPacketHeader(headerScratchBits, header); payloadBytesRead = 0; frameBytes += header.packetLength + header.headerLength; @@ -286,7 +293,7 @@ public final class MpeghReader implements ElementaryStreamReader { */ private void copyToDataScratchBuffer(ParsableByteArray data) { // read bytes from the end of the header scratch buffer into the data scratch buffer - if (headerScratchBytes.getPosition() != MpeghUtil.MAX_MHAS_PACKET_HEADER_SIZE) { + if (headerScratchBytes.getPosition() != MAX_MHAS_PACKET_HEADER_SIZE) { copyData(headerScratchBytes, dataScratchBytes, header.packetLength); } // read bytes from input data into the data scratch buffer @@ -331,7 +338,7 @@ public final class MpeghReader implements ElementaryStreamReader { private void writeSampleData(ParsableByteArray data) { int bytesToRead; // read bytes from the end of the header scratch buffer and write them into the output - if (headerScratchBytes.getPosition() != MpeghUtil.MAX_MHAS_PACKET_HEADER_SIZE) { + if (headerScratchBytes.getPosition() != MAX_MHAS_PACKET_HEADER_SIZE) { bytesToRead = min(headerScratchBytes.bytesLeft(), header.packetLength - payloadBytesRead); output.sampleData(headerScratchBytes, bytesToRead); payloadBytesRead += bytesToRead; diff --git a/libraries/extractor/src/main/java/androidx/media3/extractor/ts/MpeghUtil.java b/libraries/extractor/src/main/java/androidx/media3/extractor/ts/MpeghUtil.java index 0bd562342d..16232f945e 100644 --- a/libraries/extractor/src/main/java/androidx/media3/extractor/ts/MpeghUtil.java +++ b/libraries/extractor/src/main/java/androidx/media3/extractor/ts/MpeghUtil.java @@ -24,6 +24,8 @@ import androidx.media3.common.C; import androidx.media3.common.ParserException; import androidx.media3.common.util.ParsableBitArray; import androidx.media3.common.util.UnstableApi; +import com.google.common.math.IntMath; +import com.google.common.math.LongMath; import java.lang.annotation.Documented; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; @@ -36,8 +38,6 @@ import java.lang.annotation.Target; /** See ISO_IEC_23003-8;2022, 14.4.4. */ private static final int MHAS_SYNC_WORD = 0xC001A5; - public static final int MAX_MHAS_PACKET_HEADER_SIZE = 15; - /** * Returns whether the lower 3 bytes of the given integer matches an MHAS sync word. See * ISO_IEC_23008-3;2022, 14.4.4. @@ -48,26 +48,27 @@ import java.lang.annotation.Target; /** * Parses an MHAS packet header. See ISO_IEC_23008-3;2022, 14.2.1, Table 222. The reading position - * of {@code data} will be modified. + * of {@code data} will be modified to be just after the end of the MHAS packet header. * * @param data The data to parse, positioned at the start of the MHAS packet header. Must be * byte-aligned. - * @return The {@link MhasPacketHeader} info. + * @param header An instance of {@link MhasPacketHeader} that will be updated with the parsed + * information. * @throws ParserException if a valid {@link MhasPacketHeader} cannot be parsed. */ - public static MhasPacketHeader parseMhasPacketHeader(ParsableBitArray data) + public static void parseMhasPacketHeader(ParsableBitArray data, MhasPacketHeader header) throws ParserException { int dataStartPos = data.getBytePosition(); - @MhasPacketHeader.Type int packetType = readEscapedIntValue(data, 3, 8, 8); - long packetLabel = readEscapedLongValue(data, 2, 8, 32); + header.packetType = readEscapedIntValue(data, 3, 8, 8); + header.packetLabel = readEscapedLongValue(data, 2, 8, 32); - if (packetLabel > 0x10) { + if (header.packetLabel > 0x10) { throw ParserException.createForUnsupportedContainerFeature( - "Contains sub-stream with an invalid packet label " + packetLabel); + "Contains sub-stream with an invalid packet label " + header.packetLabel); } - if (packetLabel == 0) { - switch (packetType) { + if (header.packetLabel == 0) { + switch (header.packetType) { case MhasPacketHeader.PACTYP_MPEGH3DACFG: throw ParserException.createForMalformedContainer( "Mpegh3daConfig packet with invalid packet label 0", /* cause= */ null); @@ -82,11 +83,8 @@ import java.lang.annotation.Target; } } - int packetLength = readEscapedIntValue(data, 11, 24, 24); - - int headerLength = data.getBytePosition() - dataStartPos; - - return new MhasPacketHeader(packetType, packetLabel, packetLength, headerLength); + header.packetLength = readEscapedIntValue(data, 11, 24, 24); + header.headerLength = data.getBytePosition() - dataStartPos; } /** @@ -304,7 +302,7 @@ import java.lang.annotation.Target; /** * Obtains the number of truncated samples of the AudioTruncationInfo from an MPEG-H bit stream. * See ISO_IEC_23008-3;2022, 14.2.2, Table 225. The reading position of {@code data} will be - * modified. + * modified to be just after the end of the AudioTruncation packet payload. * * @param data The data to parse, positioned at the start of the payload of an AudioTruncation * packet. @@ -552,17 +550,19 @@ import java.lang.annotation.Target; * in reading a value greater than {@link Integer#MAX_VALUE}. */ private static int readEscapedIntValue(ParsableBitArray data, int bits1, int bits2, int bits3) { - // Check that a max possible escaped value doesn't exceed the max value that can be stored in an - // Integer. - checkArgument(Integer.MAX_VALUE - (1L << bits1) - (1L << bits2) - (1L << bits3) + 3 >= 0); + // Ensure that the calculated value will fit within the range of a Java {@code int}. + int maxBitCount = Math.max(Math.max(bits1, bits2), bits3); + checkArgument(maxBitCount <= Integer.SIZE - 1); + // Result is intentionally unused, checking if the operation causes overflow + int unused = + IntMath.checkedAdd(IntMath.checkedAdd((1 << bits1) - 1, (1 << bits2) - 1), (1 << bits3)); int value = data.readBits(bits1); - - if (value == (1L << bits1) - 1) { + if (value == (1 << bits1) - 1) { int valueAdd = data.readBits(bits2); value += valueAdd; - if (valueAdd == (1L << bits2) - 1) { + if (valueAdd == (1 << bits2) - 1) { valueAdd = data.readBits(bits3); value += valueAdd; } @@ -587,12 +587,15 @@ import java.lang.annotation.Target; * in reading a value greater than {@link Long#MAX_VALUE}. */ private static long readEscapedLongValue(ParsableBitArray data, int bits1, int bits2, int bits3) { - // Check that a max possible escaped value doesn't exceed the max value that can be stored in a - // Long. - checkArgument(Long.MAX_VALUE - (1L << bits1) - (1L << bits2) - (1L << bits3) + 3 >= 0); + // Ensure that the calculated value will fit within the range of a Java {@code long}. + int maxBitCount = Math.max(Math.max(bits1, bits2), bits3); + checkArgument(maxBitCount <= Long.SIZE - 1); + // Result is intentionally unused, checking if the operation causes overflow + long unused = + LongMath.checkedAdd( + LongMath.checkedAdd((1L << bits1) - 1, (1L << bits2) - 1), (1L << bits3)); long value = data.readBitsToLong(bits1); - if (value == (1L << bits1) - 1) { long valueAdd = data.readBitsToLong(bits2); value += valueAdd; @@ -665,25 +668,16 @@ import java.lang.annotation.Target; public static final int PACTYP_LOUDNESS = 22; /** The payload type in the actual packet. */ - public final @Type int packetType; + public @Type int packetType; /** A label indicating which packets belong together. */ - public final long packetLabel; + public long packetLabel; /** The length of MHAS packet payload in bytes. */ - public final int packetLength; + public int packetLength; /** The length of MHAS packet header in bytes. */ - public final int headerLength; - - /** Creates an instance. */ - public MhasPacketHeader( - @Type int packetType, long packetLabel, int packetLength, int headerLength) { - this.packetType = packetType; - this.packetLabel = packetLabel; - this.packetLength = packetLength; - this.headerLength = headerLength; - } + public int headerLength; } /** Represents an MPEG-H 3D audio configuration. */