From b6fbe7b393e3a9d2683585b4108850481b50074b Mon Sep 17 00:00:00 2001 From: Damiaan Dufaux Date: Thu, 3 Aug 2017 11:54:13 +0200 Subject: [PATCH] Decode dictionaries containing bools --- MsgPack/Decoder.swift | 104 ++++++++++++++++++++++----- MsgPack/Format.swift | 91 ++++++++++++----------- MsgPackTests/MsgPackTests.swift | 13 +++- Playground.playground/Contents.swift | 16 +++-- 4 files changed, 161 insertions(+), 63 deletions(-) diff --git a/MsgPack/Decoder.swift b/MsgPack/Decoder.swift index 68fa402..ba20d65 100644 --- a/MsgPack/Decoder.swift +++ b/MsgPack/Decoder.swift @@ -24,13 +24,15 @@ class IntermediateDecoder: Swift.Decoder { var storage: Data var offset = 0 + + var dictionary = [String:(FormatID, Int)]() init(with data: Data) { storage = data } func container(keyedBy type: Key.Type) throws -> KeyedDecodingContainer where Key : CodingKey { - fatalError() + return KeyedDecodingContainer(try MsgPckKeyedDecodingContainer(referringTo: self, at: offset, with: [])) } func unkeyedContainer() throws -> UnkeyedDecodingContainer { @@ -38,23 +40,20 @@ class IntermediateDecoder: Swift.Decoder { } func singleValueContainer() throws -> SingleValueDecodingContainer { - return MsgPckSingleValueDecodingContainer(refferingTo: self) + return MsgPckSingleValueDecodingContainer(decoder: self, base: 0, codingPath: []) } } struct MsgPckSingleValueDecodingContainer: SingleValueDecodingContainer { - var codingPath = [CodingKey]() - var decoder: IntermediateDecoder - var base = 0 + let decoder: IntermediateDecoder + var base: Int + + var codingPath: [CodingKey] enum Error: Swift.Error { case invalidFormat(UInt8) } - init(refferingTo decoder: IntermediateDecoder) { - self.decoder = decoder - } - func decodeNil() -> Bool { return decoder.storage[base] == FormatID.nil.rawValue } @@ -147,8 +146,8 @@ struct MsgPckSingleValueDecodingContainer: SingleValueDecodingContainer { return .init(bitPattern: bitPattern) } - func decode(_ type: String.Type) throws -> String { - return try Format.string(from: &decoder.storage) + mutating func decode(_ type: String.Type) throws -> String { + return try Format.string(from: &decoder.storage, base: &base) } func decode(_ type: T.Type) throws -> T where T : Decodable { @@ -159,18 +158,92 @@ struct MsgPckSingleValueDecodingContainer: SingleValueDecodingContainer { struct MsgPckKeyedDecodingContainer: KeyedDecodingContainerProtocol { var codingPath: [CodingKey] - var allKeys: [K] + let decoder: IntermediateDecoder + var base = 0 + + var allKeys = [K]() + + init(referringTo decoder: IntermediateDecoder, at base: Int, with codingPath: [CodingKey]) throws { + self.codingPath = codingPath + self.decoder = decoder + self.base = base + + let elementCount: Int + switch decoder.storage[base] { + case FormatID.fixMapRange: + elementCount = Int(decoder.storage[base] - FormatID.fixMap.rawValue) + self.base += 1 + case FormatID.map16.rawValue: + elementCount = Int(decoder.storage.bigEndianInteger(at: base+1) as UInt16) + self.base += 3 + case FormatID.map32.rawValue: + elementCount = Int(decoder.storage.bigEndianInteger(at: base+1) as UInt32) + self.base += 5 + default: + throw DecodingError.typeMismatch(Dictionary.self, .init(codingPath: codingPath, debugDescription: "Expected a MsgPack map format, but found 0x\(String(decoder.storage[base], radix: 16))")) + } + for cursor in 0 ..< elementCount { + let key = try Format.string(from: &decoder.storage, base: &self.base) + + let valueFormat: FormatID + if let valueFormatLookup = FormatID(rawValue: decoder.storage[self.base]) { + valueFormat = valueFormatLookup + } else { + switch decoder.storage[self.base] { + case FormatID.fixMapRange: + valueFormat = .fixMap + case FormatID.fixStringRange: + valueFormat = .fixString + case FormatID.positiveInt7Range: + valueFormat = .positiveInt7 + case FormatID.negativeInt5Range: + valueFormat = .negativeInt5 + default: + throw DecodingError.dataCorrupted(.init(codingPath: codingPath, debugDescription: "Unknown format: 0x\(String(decoder.storage[self.base], radix: 16))")) + } + } + + let length: Int + switch valueFormat { + case .positiveInt7, .negativeInt5, .fixArray, .fixMap, .fixString: + length = Int(decoder.storage[self.base] - valueFormat.rawValue) + self.base += length + 1 + case .uInt8, .int8, .string8: + length = Int(decoder.storage[self.base+1]) + self.base += length + 2 + case .uInt16, .int16, .string16, .array16, .map16: + length = Int(decoder.storage.bigEndianInteger(at: self.base + 1) as UInt16) + self.base += length + 3 + case .uInt32, .int32, .string32, .array32, .map32, .float32: + length = Int(decoder.storage.bigEndianInteger(at: self.base + 1) as UInt32) + self.base += length + 5 + case .uInt64, .int64, .float64: + length = Int(decoder.storage.bigEndianInteger(at: self.base + 1) as UInt64) + self.base += length + 9 + case .nil, .false, .true: + length = 0 + self.base += 1 + } + decoder.dictionary[key] = (valueFormat, length) + } + } func contains(_ key: K) -> Bool { - fatalError("not implemented") + return decoder.dictionary[key.stringValue] != nil } func decodeNil(forKey key: K) throws -> Bool { - fatalError("not implemented") + return decoder.dictionary[key.stringValue]?.0 == FormatID.nil } func decode(_ type: Bool.Type, forKey key: K) throws -> Bool { - fatalError("not implemented") + let format = decoder.dictionary[key.stringValue]!.0 + switch format { + case .true, .false: + return format == .true + default: + throw DecodingError.typeMismatch(type, .init(codingPath: codingPath, debugDescription: "Expected bool but found \(format)")) + } } func decode(_ type: Int.Type, forKey key: K) throws -> Int { @@ -246,5 +319,4 @@ struct MsgPckKeyedDecodingContainer: KeyedDecodingContainerProtoco } typealias Key = K - } diff --git a/MsgPack/Format.swift b/MsgPack/Format.swift index c5df725..4b3e687 100644 --- a/MsgPack/Format.swift +++ b/MsgPack/Format.swift @@ -52,22 +52,19 @@ public enum FormatID: UInt8 { case `false` = 0xC2 case `true` - case bin8 - case bin16 - case bin32 +// case bin8 +// case bin16 +// case bin32 +// +// case ext8 +// case ext16 +// case ext32 - case ext8 - case ext16 - case ext32 - - case float32 + case float32 = 0xCA case float64 - case firstPositiveInt7 = 0b00000000 - case lastPositiveInt7 = 0b01111111 - - case firstNegativeInt5 = 0b11100000 - case lastNegativeInt5 = 0b11111111 + case positiveInt7 = 0b00000000 + case negativeInt5 = 0b11100000 case uInt8 = 0xCC case uInt16 @@ -79,27 +76,31 @@ public enum FormatID: UInt8 { case int32 case int64 - case fixExt1 - case fixExt2 - case fixExt4 - case fixExt8 - case fixExt16 +// case fixExt1 +// case fixExt2 +// case fixExt4 +// case fixExt8 +// case fixExt16 - case fixStringStart = 0b10100000 - case fixStringEnd = 0b10111111 + case fixString = 0b10100000 case string8 = 0xD9 case string16 case string32 - case fixArrayStart = 0b10010000 - case fixArrayEnd = 0b10011111 + case fixArray = 0b10010000 case array16 = 0xDC case array32 - case fixMapStart = 0b10000000 - case fixMapEnd = 0b10001111 + case fixMap = 0b10000000 case map16 = 0xDE case map32 + + static let positiveInt7Range = FormatID.positiveInt7.rawValue ..< FormatID.fixMap.rawValue + static let negativeInt5Range = FormatID.negativeInt5.rawValue ..< 0b11111111 + + static let fixMapRange = FormatID.fixMap.rawValue ..< FormatID.fixArray.rawValue + static let fixArrayRange = FormatID.fixArray.rawValue ..< FormatID.fixString.rawValue + static let fixStringRange = FormatID.fixString.rawValue ..< FormatID.nil.rawValue } extension Format { @@ -116,9 +117,9 @@ extension Format { // MARK: Small integers (< 8 bit) case .positiveInt7(let value): - data.append(value | FormatID.firstPositiveInt7.rawValue) + data.append(value | FormatID.positiveInt7.rawValue) case .negativeInt5(let value): - data.append(value | FormatID.firstNegativeInt5.rawValue) + data.append(value | FormatID.negativeInt5.rawValue) // MARK: Unsigned integers case .uInt8(let value): @@ -179,7 +180,7 @@ extension Format { // MARK: Strings case .fixString(let utf8Data): precondition(utf8Data.count < 32, "fix strings cannot contain more than 31 bytes") - data.append( UInt8(utf8Data.count) | FormatID.fixStringStart.rawValue) + data.append( UInt8(utf8Data.count) | FormatID.fixString.rawValue) data.append(utf8Data) case .string8(let utf8Data): data.append(contentsOf: [FormatID.string8.rawValue, UInt8(utf8Data.count)]) @@ -200,7 +201,7 @@ extension Format { // MARK: Arrays case .fixArray(let array): precondition(array.count < 16, "fix arrays cannot contain more than 15 elements") - data.append( UInt8(array.count) | FormatID.fixArrayStart.rawValue) + data.append( UInt8(array.count) | FormatID.fixArray.rawValue) for element in array { element.appendTo(data: &data) } @@ -224,7 +225,7 @@ extension Format { // MARK: Maps case .fixMap(let pairs): precondition(pairs.count < 16, "fix maps cannot contain more than 15 key-value pairs") - data.append( UInt8(pairs.count) | FormatID.fixMapStart.rawValue) + data.append( UInt8(pairs.count) | FormatID.fixMap.rawValue) for (key, value) in pairs { key.appendTo(data: &data) value.appendTo(data: &data) @@ -329,42 +330,50 @@ extension Format { } extension Format { - static func string(from data: inout Data, offset: Int = 0) throws -> String { - switch data[offset] { - case FormatID.fixStringStart.rawValue ... FormatID.fixStringEnd.rawValue: - let length = Int(data[offset] & 0b00011111) + static func string(from data: inout Data, base: inout Int) throws -> String { + switch data[base] { + case FormatID.fixStringRange: + let length = Int(data[base] & 0b00011111) + base += 1 guard let string = data.withUnsafeMutableBytes({ - String(bytesNoCopy: $0.advanced(by: offset + 1), length: length, encoding: .utf8, freeWhenDone: false) + String(bytesNoCopy: $0.advanced(by: base), length: length, encoding: .utf8, freeWhenDone: false) }) else { throw DecodingError.dataCorrupted(.init(codingPath: [], debugDescription: "not a valid string")) } + base += length return string case FormatID.string8.rawValue: - let length = Int(data[offset + 1]) + let length = Int(data[base + 1]) + base += 2 guard let string = data.withUnsafeMutableBytes({ - String(bytesNoCopy: $0.advanced(by: offset + 2), length: length, encoding: .utf8, freeWhenDone: false) + String(bytesNoCopy: $0.advanced(by: base), length: length, encoding: .utf8, freeWhenDone: false) }) else { throw DecodingError.dataCorrupted(.init(codingPath: [], debugDescription: "not a valid string")) } + base += length return string case FormatID.string16.rawValue: - let length = Int(data.bigEndianInteger(at: offset) as UInt16) + let length = Int(data.bigEndianInteger(at: base + 1) as UInt16) + base += 3 guard let string = data.withUnsafeMutableBytes({ - String(bytesNoCopy: $0.advanced(by: offset + 3), length: length, encoding: .utf8, freeWhenDone: false) + String(bytesNoCopy: $0.advanced(by: base), length: length, encoding: .utf8, freeWhenDone: false) }) else { throw DecodingError.dataCorrupted(.init(codingPath: [], debugDescription: "not a valid string")) } + base += length return string case FormatID.string32.rawValue: - let length = Int(data.bigEndianInteger(at: offset) as UInt16) + let length = Int(data.bigEndianInteger(at: base + 1) as UInt16) + base += 5 guard let string = data.withUnsafeMutableBytes({ - String(bytesNoCopy: $0.advanced(by: offset + 5), length: length, encoding: .utf8, freeWhenDone: false) + String(bytesNoCopy: $0.advanced(by: base), length: length, encoding: .utf8, freeWhenDone: false) }) else { throw DecodingError.dataCorrupted(.init(codingPath: [], debugDescription: "not a valid string")) } + base += length return string default: - throw DecodingError.typeMismatch(String.self, .init(codingPath: [], debugDescription: "Wrong string format: \(data[offset])")) + throw DecodingError.typeMismatch(String.self, .init(codingPath: [], debugDescription: "Wrong string format: \(data[base])")) } } } diff --git a/MsgPackTests/MsgPackTests.swift b/MsgPackTests/MsgPackTests.swift index af1de1b..72a2955 100644 --- a/MsgPackTests/MsgPackTests.swift +++ b/MsgPackTests/MsgPackTests.swift @@ -16,11 +16,13 @@ class MsgPackTests: XCTestCase { // } var encoder: MsgPack.Encoder! + var decoder: MsgPack.Decoder! override func setUp() { super.setUp() encoder = Encoder() + decoder = Decoder() } @@ -108,9 +110,18 @@ class MsgPackTests: XCTestCase { } } + struct Simple: Codable { + let a: Bool + let b: Bool + } + + func roundtrip(value: T) throws -> T { + return try decoder.decode(T.self, from: encoder.encode(value)) + } + func testExample() { do { - print(try encoder.encode([[Int8(-3)], [5], [8]])) + print("roundtrip:", try roundtrip(value: Simple(a: false, b: false))) } catch { print(error) } diff --git a/Playground.playground/Contents.swift b/Playground.playground/Contents.swift index 697d4c0..cab1dbe 100644 --- a/Playground.playground/Contents.swift +++ b/Playground.playground/Contents.swift @@ -33,10 +33,7 @@ let graph = Graph( ] ) -let encodedGraph = try encoder.encode(graph) -for byte in encodedGraph { - print(String(byte, radix:16)) -} +try encoder.encode(graph) let decoder = Decoder() @@ -44,4 +41,13 @@ func roundtrip(value: T) throws -> T { return try decoder.decode(T.self, from: encoder.encode(value)) } -try roundtrip(value: "Hello") +struct Simple: Codable { + let a: Bool + let b: Bool + let c: Bool? + let d: Bool? + let e: Bool? +} + +try roundtrip(value: -56.4) +try roundtrip(value: Simple(a: true, b: false, c: true, d: false, e: nil))