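Changes based on internal review

- MpeghReader: reuse a single MhasPacketHeader instance and a single
  ParsableBitArray (headerScratchBits) for header parsing, instead of
  allocating new objects for every packet.
- Move MAX_MHAS_PACKET_HEADER_SIZE from MpeghUtil into MpeghReader as a
  private constant.
- MpeghUtil.parseMhasPacketHeader now fills in a caller-supplied
  MhasPacketHeader, whose fields are no longer final. The MhasPacketHeader
  constructor is removed.
- Rework the overflow checks in readEscapedIntValue/readEscapedLongValue to
  use IntMath/LongMath.checkedAdd.
- Clarify Javadoc on where the reading position ends up after parsing.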

This commit is contained in:
Rohit Singh 2024-04-18 16:29:38 +01:00
parent e412987248
commit 5b3a50fca8
2 changed files with 51 additions and 50 deletions

View file

@ -55,6 +55,8 @@ public final class MpeghReader implements ElementaryStreamReader {
private static final int STATE_READING_PACKET_HEADER = 1; private static final int STATE_READING_PACKET_HEADER = 1;
private static final int STATE_READING_PACKET_PAYLOAD = 2; private static final int STATE_READING_PACKET_PAYLOAD = 2;
private static final int MAX_MHAS_PACKET_HEADER_SIZE = 15;
private @State int state; private @State int state;
private @MonotonicNonNull String formatId; private @MonotonicNonNull String formatId;
@ -70,6 +72,7 @@ public final class MpeghReader implements ElementaryStreamReader {
private int syncBytes; private int syncBytes;
private final ParsableByteArray headerScratchBytes; private final ParsableByteArray headerScratchBytes;
private final ParsableBitArray headerScratchBits;
private boolean headerDataFinished; private boolean headerDataFinished;
private final ParsableByteArray dataScratchBytes; private final ParsableByteArray dataScratchBytes;
@ -77,7 +80,7 @@ public final class MpeghReader implements ElementaryStreamReader {
private int payloadBytesRead; private int payloadBytesRead;
private int frameBytes; private int frameBytes;
@Nullable private MpeghUtil.MhasPacketHeader header; private MpeghUtil.MhasPacketHeader header;
private int samplingRate; private int samplingRate;
private int standardFrameLength; private int standardFrameLength;
private int truncationSamples; private int truncationSamples;
@ -87,8 +90,10 @@ public final class MpeghReader implements ElementaryStreamReader {
/** Constructs a new reader for MPEG-H elementary streams. */ /** Constructs a new reader for MPEG-H elementary streams. */
public MpeghReader() { public MpeghReader() {
state = STATE_FINDING_SYNC; state = STATE_FINDING_SYNC;
headerScratchBytes = new ParsableByteArray(new byte[MpeghUtil.MAX_MHAS_PACKET_HEADER_SIZE]); headerScratchBytes = new ParsableByteArray(new byte[MAX_MHAS_PACKET_HEADER_SIZE]);
headerScratchBits = new ParsableBitArray();
dataScratchBytes = new ParsableByteArray(); dataScratchBytes = new ParsableByteArray();
header = new MpeghUtil.MhasPacketHeader();
samplingRate = C.RATE_UNSET_INT; samplingRate = C.RATE_UNSET_INT;
standardFrameLength = C.LENGTH_UNSET; standardFrameLength = C.LENGTH_UNSET;
mainStreamLabel = C.INDEX_UNSET; mainStreamLabel = C.INDEX_UNSET;
@ -102,9 +107,9 @@ public final class MpeghReader implements ElementaryStreamReader {
state = STATE_FINDING_SYNC; state = STATE_FINDING_SYNC;
syncBytes = 0; syncBytes = 0;
headerScratchBytes.setPosition(0); headerScratchBytes.setPosition(0);
headerScratchBits.setPosition(0);
dataScratchBytes.setPosition(0); dataScratchBytes.setPosition(0);
dataScratchBytes.setLimit(0); dataScratchBytes.setLimit(0);
header = null;
headerDataFinished = false; headerDataFinished = false;
payloadBytesRead = 0; payloadBytesRead = 0;
frameBytes = 0; frameBytes = 0;
@ -160,7 +165,7 @@ public final class MpeghReader implements ElementaryStreamReader {
case STATE_READING_PACKET_HEADER: case STATE_READING_PACKET_HEADER:
maybeAdjustHeaderScratchBuffer(); maybeAdjustHeaderScratchBuffer();
// read into header scratch buffer // read into header scratch buffer
if (continueRead(data, headerScratchBytes, MpeghUtil.MAX_MHAS_PACKET_HEADER_SIZE)) { if (continueRead(data, headerScratchBytes, MAX_MHAS_PACKET_HEADER_SIZE)) {
parseHeader(); parseHeader();
// write the packet header to output // write the packet header to output
headerScratchBytes.setPosition(0); headerScratchBytes.setPosition(0);
@ -177,11 +182,12 @@ public final class MpeghReader implements ElementaryStreamReader {
} }
writeSampleData(data); writeSampleData(data);
if (payloadBytesRead == header.packetLength) { if (payloadBytesRead == header.packetLength) {
ParsableBitArray bitArray = new ParsableBitArray(dataScratchBytes.getData());
if (header.packetType == MpeghUtil.MhasPacketHeader.PACTYP_MPEGH3DACFG) { if (header.packetType == MpeghUtil.MhasPacketHeader.PACTYP_MPEGH3DACFG) {
parseConfig(bitArray); parseConfig(new ParsableBitArray(dataScratchBytes.getData()));
} else if (header.packetType == MpeghUtil.MhasPacketHeader.PACTYP_AUDIOTRUNCATION) { } else if (header.packetType == MpeghUtil.MhasPacketHeader.PACTYP_AUDIOTRUNCATION) {
truncationSamples = MpeghUtil.parseAudioTruncationInfo(bitArray); truncationSamples =
MpeghUtil.parseAudioTruncationInfo(
new ParsableBitArray(dataScratchBytes.getData()));
} else if (header.packetType == MpeghUtil.MhasPacketHeader.PACTYP_MPEGH3DAFRAME) { } else if (header.packetType == MpeghUtil.MhasPacketHeader.PACTYP_MPEGH3DAFRAME) {
finalizeFrame(); finalizeFrame();
} }
@ -253,8 +259,9 @@ public final class MpeghReader implements ElementaryStreamReader {
* @throws ParserException if a valid {@link MpeghUtil.Mpegh3daConfig} cannot be parsed. * @throws ParserException if a valid {@link MpeghUtil.Mpegh3daConfig} cannot be parsed.
*/ */
private void parseHeader() throws ParserException { private void parseHeader() throws ParserException {
headerScratchBits.reset(headerScratchBytes.getData());
// parse the MHAS packet header // parse the MHAS packet header
header = MpeghUtil.parseMhasPacketHeader(new ParsableBitArray(headerScratchBytes.getData())); MpeghUtil.parseMhasPacketHeader(headerScratchBits, header);
payloadBytesRead = 0; payloadBytesRead = 0;
frameBytes += header.packetLength + header.headerLength; frameBytes += header.packetLength + header.headerLength;
@ -286,7 +293,7 @@ public final class MpeghReader implements ElementaryStreamReader {
*/ */
private void copyToDataScratchBuffer(ParsableByteArray data) { private void copyToDataScratchBuffer(ParsableByteArray data) {
// read bytes from the end of the header scratch buffer into the data scratch buffer // read bytes from the end of the header scratch buffer into the data scratch buffer
if (headerScratchBytes.getPosition() != MpeghUtil.MAX_MHAS_PACKET_HEADER_SIZE) { if (headerScratchBytes.getPosition() != MAX_MHAS_PACKET_HEADER_SIZE) {
copyData(headerScratchBytes, dataScratchBytes, header.packetLength); copyData(headerScratchBytes, dataScratchBytes, header.packetLength);
} }
// read bytes from input data into the data scratch buffer // read bytes from input data into the data scratch buffer
@ -331,7 +338,7 @@ public final class MpeghReader implements ElementaryStreamReader {
private void writeSampleData(ParsableByteArray data) { private void writeSampleData(ParsableByteArray data) {
int bytesToRead; int bytesToRead;
// read bytes from the end of the header scratch buffer and write them into the output // read bytes from the end of the header scratch buffer and write them into the output
if (headerScratchBytes.getPosition() != MpeghUtil.MAX_MHAS_PACKET_HEADER_SIZE) { if (headerScratchBytes.getPosition() != MAX_MHAS_PACKET_HEADER_SIZE) {
bytesToRead = min(headerScratchBytes.bytesLeft(), header.packetLength - payloadBytesRead); bytesToRead = min(headerScratchBytes.bytesLeft(), header.packetLength - payloadBytesRead);
output.sampleData(headerScratchBytes, bytesToRead); output.sampleData(headerScratchBytes, bytesToRead);
payloadBytesRead += bytesToRead; payloadBytesRead += bytesToRead;

View file

@ -24,6 +24,8 @@ import androidx.media3.common.C;
import androidx.media3.common.ParserException; import androidx.media3.common.ParserException;
import androidx.media3.common.util.ParsableBitArray; import androidx.media3.common.util.ParsableBitArray;
import androidx.media3.common.util.UnstableApi; import androidx.media3.common.util.UnstableApi;
import com.google.common.math.IntMath;
import com.google.common.math.LongMath;
import java.lang.annotation.Documented; import java.lang.annotation.Documented;
import java.lang.annotation.Retention; import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy; import java.lang.annotation.RetentionPolicy;
@ -36,8 +38,6 @@ import java.lang.annotation.Target;
/** See ISO_IEC_23003-8;2022, 14.4.4. */ /** See ISO_IEC_23003-8;2022, 14.4.4. */
private static final int MHAS_SYNC_WORD = 0xC001A5; private static final int MHAS_SYNC_WORD = 0xC001A5;
public static final int MAX_MHAS_PACKET_HEADER_SIZE = 15;
/** /**
* Returns whether the lower 3 bytes of the given integer matches an MHAS sync word. See * Returns whether the lower 3 bytes of the given integer matches an MHAS sync word. See
* ISO_IEC_23008-3;2022, 14.4.4. * ISO_IEC_23008-3;2022, 14.4.4.
@ -48,26 +48,27 @@ import java.lang.annotation.Target;
/** /**
* Parses an MHAS packet header. See ISO_IEC_23008-3;2022, 14.2.1, Table 222. The reading position * Parses an MHAS packet header. See ISO_IEC_23008-3;2022, 14.2.1, Table 222. The reading position
* of {@code data} will be modified. * of {@code data} will be modified to be just after the end of the MHAS packet header.
* *
* @param data The data to parse, positioned at the start of the MHAS packet header. Must be * @param data The data to parse, positioned at the start of the MHAS packet header. Must be
* byte-aligned. * byte-aligned.
* @return The {@link MhasPacketHeader} info. * @param header An instance of {@link MhasPacketHeader} that will be updated with the parsed
* information.
* @throws ParserException if a valid {@link MhasPacketHeader} cannot be parsed. * @throws ParserException if a valid {@link MhasPacketHeader} cannot be parsed.
*/ */
public static MhasPacketHeader parseMhasPacketHeader(ParsableBitArray data) public static void parseMhasPacketHeader(ParsableBitArray data, MhasPacketHeader header)
throws ParserException { throws ParserException {
int dataStartPos = data.getBytePosition(); int dataStartPos = data.getBytePosition();
@MhasPacketHeader.Type int packetType = readEscapedIntValue(data, 3, 8, 8); header.packetType = readEscapedIntValue(data, 3, 8, 8);
long packetLabel = readEscapedLongValue(data, 2, 8, 32); header.packetLabel = readEscapedLongValue(data, 2, 8, 32);
if (packetLabel > 0x10) { if (header.packetLabel > 0x10) {
throw ParserException.createForUnsupportedContainerFeature( throw ParserException.createForUnsupportedContainerFeature(
"Contains sub-stream with an invalid packet label " + packetLabel); "Contains sub-stream with an invalid packet label " + header.packetLabel);
} }
if (packetLabel == 0) { if (header.packetLabel == 0) {
switch (packetType) { switch (header.packetType) {
case MhasPacketHeader.PACTYP_MPEGH3DACFG: case MhasPacketHeader.PACTYP_MPEGH3DACFG:
throw ParserException.createForMalformedContainer( throw ParserException.createForMalformedContainer(
"Mpegh3daConfig packet with invalid packet label 0", /* cause= */ null); "Mpegh3daConfig packet with invalid packet label 0", /* cause= */ null);
@ -82,11 +83,8 @@ import java.lang.annotation.Target;
} }
} }
int packetLength = readEscapedIntValue(data, 11, 24, 24); header.packetLength = readEscapedIntValue(data, 11, 24, 24);
header.headerLength = data.getBytePosition() - dataStartPos;
int headerLength = data.getBytePosition() - dataStartPos;
return new MhasPacketHeader(packetType, packetLabel, packetLength, headerLength);
} }
/** /**
@ -304,7 +302,7 @@ import java.lang.annotation.Target;
/** /**
* Obtains the number of truncated samples of the AudioTruncationInfo from an MPEG-H bit stream. * Obtains the number of truncated samples of the AudioTruncationInfo from an MPEG-H bit stream.
* See ISO_IEC_23008-3;2022, 14.2.2, Table 225. The reading position of {@code data} will be * See ISO_IEC_23008-3;2022, 14.2.2, Table 225. The reading position of {@code data} will be
* modified. * modified to be just after the end of the AudioTruncation packet payload.
* *
* @param data The data to parse, positioned at the start of the payload of an AudioTruncation * @param data The data to parse, positioned at the start of the payload of an AudioTruncation
* packet. * packet.
@ -552,17 +550,19 @@ import java.lang.annotation.Target;
* in reading a value greater than {@link Integer#MAX_VALUE}. * in reading a value greater than {@link Integer#MAX_VALUE}.
*/ */
private static int readEscapedIntValue(ParsableBitArray data, int bits1, int bits2, int bits3) { private static int readEscapedIntValue(ParsableBitArray data, int bits1, int bits2, int bits3) {
// Check that a max possible escaped value doesn't exceed the max value that can be stored in an // Ensure that the calculated value will fit within the range of a Java {@code int}.
// Integer. int maxBitCount = Math.max(Math.max(bits1, bits2), bits3);
checkArgument(Integer.MAX_VALUE - (1L << bits1) - (1L << bits2) - (1L << bits3) + 3 >= 0); checkArgument(maxBitCount <= Integer.SIZE - 1);
// Result is intentionally unused, checking if the operation causes overflow
int unused =
IntMath.checkedAdd(IntMath.checkedAdd((1 << bits1) - 1, (1 << bits2) - 1), (1 << bits3));
int value = data.readBits(bits1); int value = data.readBits(bits1);
if (value == (1 << bits1) - 1) {
if (value == (1L << bits1) - 1) {
int valueAdd = data.readBits(bits2); int valueAdd = data.readBits(bits2);
value += valueAdd; value += valueAdd;
if (valueAdd == (1L << bits2) - 1) { if (valueAdd == (1 << bits2) - 1) {
valueAdd = data.readBits(bits3); valueAdd = data.readBits(bits3);
value += valueAdd; value += valueAdd;
} }
@ -587,12 +587,15 @@ import java.lang.annotation.Target;
* in reading a value greater than {@link Long#MAX_VALUE}. * in reading a value greater than {@link Long#MAX_VALUE}.
*/ */
private static long readEscapedLongValue(ParsableBitArray data, int bits1, int bits2, int bits3) { private static long readEscapedLongValue(ParsableBitArray data, int bits1, int bits2, int bits3) {
// Check that a max possible escaped value doesn't exceed the max value that can be stored in a // Ensure that the calculated value will fit within the range of a Java {@code long}.
// Long. int maxBitCount = Math.max(Math.max(bits1, bits2), bits3);
checkArgument(Long.MAX_VALUE - (1L << bits1) - (1L << bits2) - (1L << bits3) + 3 >= 0); checkArgument(maxBitCount <= Long.SIZE - 1);
// Result is intentionally unused, checking if the operation causes overflow
long unused =
LongMath.checkedAdd(
LongMath.checkedAdd((1L << bits1) - 1, (1L << bits2) - 1), (1L << bits3));
long value = data.readBitsToLong(bits1); long value = data.readBitsToLong(bits1);
if (value == (1L << bits1) - 1) { if (value == (1L << bits1) - 1) {
long valueAdd = data.readBitsToLong(bits2); long valueAdd = data.readBitsToLong(bits2);
value += valueAdd; value += valueAdd;
@ -665,25 +668,16 @@ import java.lang.annotation.Target;
public static final int PACTYP_LOUDNESS = 22; public static final int PACTYP_LOUDNESS = 22;
/** The payload type in the actual packet. */ /** The payload type in the actual packet. */
public final @Type int packetType; public @Type int packetType;
/** A label indicating which packets belong together. */ /** A label indicating which packets belong together. */
public final long packetLabel; public long packetLabel;
/** The length of MHAS packet payload in bytes. */ /** The length of MHAS packet payload in bytes. */
public final int packetLength; public int packetLength;
/** The length of MHAS packet header in bytes. */ /** The length of MHAS packet header in bytes. */
public final int headerLength; public int headerLength;
/** Creates an instance. */
public MhasPacketHeader(
@Type int packetType, long packetLabel, int packetLength, int headerLength) {
this.packetType = packetType;
this.packetLabel = packetLabel;
this.packetLength = packetLength;
this.headerLength = headerLength;
}
} }
/** Represents an MPEG-H 3D audio configuration. */ /** Represents an MPEG-H 3D audio configuration. */