From 680c0a3c94f0bb84a2773bc9a95dc5399b6925fb Mon Sep 17 00:00:00 2001 From: pacien Date: Sun, 25 Nov 2018 16:45:35 +0100 Subject: Fix bitreader look-ahead overflow --- src/bitreader.nim | 52 ++++++++++++++++++---------------------------------- 1 file changed, 18 insertions(+), 34 deletions(-) (limited to 'src/bitreader.nim') diff --git a/src/bitreader.nim b/src/bitreader.nim index 757c1b3..7afb13d 100644 --- a/src/bitreader.nim +++ b/src/bitreader.nim @@ -17,49 +17,33 @@ import streams import integers -# Stream functions - -proc newEIO(msg: string): ref IOError = - new(result) - result.msg = msg - -proc read[T](s: Stream, t: typedesc[T]): T = - if readData(s, addr(result), sizeof(T)) != sizeof(T): - raise newEIO("cannot read from stream") - -proc peek[T](s: Stream, t: typedesc[T]): T = - if peekData(s, addr(result), sizeof(T)) != sizeof(T): - raise newEIO("cannot read from stream") - -# BitReader - type BitReader* = ref object stream: Stream bitOffset: int + overflowBuffer: uint8 proc bitReader*(stream: Stream): BitReader = - BitReader(stream: stream, bitOffset: 0) + BitReader(stream: stream, bitOffset: 0, overflowBuffer: 0) proc atEnd*(bitReader: BitReader): bool = - bitReader.stream.atEnd() + bitReader.bitOffset == 0 and bitReader.stream.atEnd() proc readBits*[T: SomeUnsignedInt](bitReader: BitReader, bits: int, to: typedesc[T]): T = - let targetBitLength = sizeof(T) * wordBitLength - if bits < 0 or bits > targetBitLength: - raise newException(RangeError, "invalid bit length") - elif bits == 0: - result = 0 - elif bits < targetBitLength - bitReader.bitOffset: - result = bitReader.stream.peek(T) shl (targetBitLength - bits - bitReader.bitOffset) shr (targetBitLength - bits) - elif bits == targetBitLength - bitReader.bitOffset: - result = bitReader.stream.read(T) shl (targetBitLength - bits - bitReader.bitOffset) shr (targetBitLength - bits) - else: - let rightBits = targetBitLength - bitReader.bitOffset - let leftBits = bits - rightBits - let right = bitReader.stream.read(T) shr bitReader.bitOffset - let left = bitReader.stream.peek(T) shl (targetBitLength - leftBits) shr (targetBitLength - bits) - result = left or right - bitReader.bitOffset = (bitReader.bitOffset + bits) mod wordBitLength + if bits < 0 or bits > sizeof(T) * wordBitLength: raise newException(RangeError, "invalid bit length") + if bits == 0: return 0 + var bitsRead = 0 + if bitReader.bitOffset > 0: + let bitsFromBuffer = min(bits, wordBitLength - bitReader.bitOffset) + result = (bitReader.overflowBuffer shr bitReader.bitOffset).leastSignificantBits(bitsFromBuffer) + bitReader.bitOffset = (bitReader.bitOffset + bitsFromBuffer) mod wordBitLength + bitsRead += bitsFromBuffer + while bits - bitsRead >= wordBitLength: + result = result or (bitReader.stream.readUint8().T shl bitsRead) + bitsRead += wordBitLength + if bits - bitsRead > 0: + bitReader.overflowBuffer = bitReader.stream.readUint8() + bitReader.bitOffset = bits - bitsRead + result = result or (bitReader.overflowBuffer.leastSignificantBits(bitReader.bitOffset).T shl bitsRead) proc readBool*(bitReader: BitReader): bool = bitReader.readBits(1, uint8) != 0 -- cgit v1.2.3