diff --git a/MsgPack/Decoder.swift b/MsgPack/Decoder.swift index 4f4b126..ba8ef47 100644 --- a/MsgPack/Decoder.swift +++ b/MsgPack/Decoder.swift @@ -17,6 +17,15 @@ public class Decoder { public init() {} } +typealias PartiallyDecodedDictionary = [String:PartiallyDecodedValue] +enum PartiallyDecodedValue { + case constant(FormatID) + case fixedWidth(FormatID, pointer: Int) + case variableWidth(FormatID, pointer: Int, length: Int) + case array([PartiallyDecodedValue]) + case dictionary(PartiallyDecodedDictionary) +} + class IntermediateDecoder: Swift.Decoder { var codingPath = [CodingKey]() @@ -25,7 +34,7 @@ class IntermediateDecoder: Swift.Decoder { var storage: Data var offset = 0 - var dictionary = [String:(FormatID, Int)]() + var dictionary = PartiallyDecodedDictionary() init(with data: Data) { storage = data @@ -169,63 +178,78 @@ struct MsgPckKeyedDecodingContainer: KeyedDecodingContainerProtoco self.decoder = decoder self.base = base - let elementCount: Int - switch decoder.storage[base] { - case FormatID.fixMapRange: - elementCount = Int(decoder.storage[base] - FormatID.fixMap.rawValue) - self.base += 1 - case FormatID.map16.rawValue: - elementCount = Int(decoder.storage.bigEndianInteger(at: base+1) as UInt16) - self.base += 3 - case FormatID.map32.rawValue: - elementCount = Int(decoder.storage.bigEndianInteger(at: base+1) as UInt32) - self.base += 5 - default: - throw DecodingError.typeMismatch(Dictionary.self, .init(codingPath: codingPath, debugDescription: "Expected a MsgPack map format, but found 0x\(String(decoder.storage[base], radix: 16))")) - } - for cursor in 0 ..< elementCount { - let key = try Format.string(from: &decoder.storage, base: &self.base) - - let valueFormat: FormatID - if let valueFormatLookup = FormatID(rawValue: decoder.storage[self.base]) { - valueFormat = valueFormatLookup - } else { - switch decoder.storage[self.base] { - case FormatID.fixMapRange: - valueFormat = .fixMap - case FormatID.fixStringRange: - valueFormat = .fixString - case FormatID.positiveInt7Range: - valueFormat = .positiveInt7 - case FormatID.negativeInt5Range: - valueFormat = .negativeInt5 - default: - throw DecodingError.dataCorrupted(.init(codingPath: codingPath, debugDescription: "Unknown format: 0x\(String(decoder.storage[self.base], radix: 16))")) + if decoder.dictionary.count == 0 { + let elementCount: Int + switch decoder.storage[base] { + case FormatID.fixMapRange: + elementCount = Int(decoder.storage[base] - FormatID.fixMap.rawValue) + self.base += 1 + case FormatID.map16.rawValue: + elementCount = Int(decoder.storage.bigEndianInteger(at: base+1) as UInt16) + self.base += 3 + case FormatID.map32.rawValue: + elementCount = Int(decoder.storage.bigEndianInteger(at: base+1) as UInt32) + self.base += 5 + default: + throw DecodingError.typeMismatch(Dictionary.self, .init(codingPath: codingPath, debugDescription: "Expected a MsgPack map format, but found 0x\(String(decoder.storage[base], radix: 16))")) + } + for cursor in 0 ..< elementCount { + let key = try Format.string(from: &decoder.storage, base: &self.base) + + let valueFormat: FormatID + if let valueFormatLookup = FormatID(rawValue: decoder.storage[self.base]) { + valueFormat = valueFormatLookup + } else { + switch decoder.storage[self.base] { + case FormatID.fixMapRange: + valueFormat = .fixMap + case FormatID.fixStringRange: + valueFormat = .fixString + case FormatID.positiveInt7Range: + valueFormat = .positiveInt7 + case FormatID.negativeInt5Range: + valueFormat = .negativeInt5 + default: + throw DecodingError.dataCorrupted(.init(codingPath: codingPath, debugDescription: "Unknown format: 0x\(String(decoder.storage[self.base], radix: 16))")) + } + } + + switch valueFormat { + case .nil, .false, .true: + decoder.dictionary[key] = .constant(valueFormat) + self.base += 1 + case .positiveInt7, .negativeInt5: + decoder.dictionary[key] = .fixedWidth(valueFormat, pointer: self.base) + self.base += 1 + case .uInt8, .int8, .string8: + decoder.dictionary[key] = .fixedWidth(valueFormat, pointer: self.base + 1) + self.base += 2 + case .uInt16, .int16: + decoder.dictionary[key] = .fixedWidth(valueFormat, pointer: self.base + 1) + self.base += 3 + case .uInt32, .int32, .float32: + decoder.dictionary[key] = .fixedWidth(valueFormat, pointer: self.base + 1) + self.base += 5 + case .uInt64, .int64, .float64: + decoder.dictionary[key] = .fixedWidth(valueFormat, pointer: self.base + 1) + self.base += 9 + case .fixArray, .fixMap, .fixString: + let length = Int(decoder.storage[self.base] - valueFormat.rawValue) + self.base += 1 + decoder.dictionary[key] = .variableWidth(valueFormat, pointer: self.base, length: length) + self.base += length + case .array16, .map16, .string16: + let length = Int(decoder.storage.bigEndianInteger(at: self.base + 1) as UInt16) + self.base += 3 + decoder.dictionary[key] = .variableWidth(valueFormat, pointer: self.base, length: length) + self.base += length + case .string32, .array32, .map32: + let length = Int(decoder.storage.bigEndianInteger(at: self.base + 1) as UInt16) + self.base += 5 + decoder.dictionary[key] = .variableWidth(valueFormat, pointer: self.base, length: length) + self.base += length } } - - let length: Int - switch valueFormat { - case .positiveInt7, .negativeInt5, .fixArray, .fixMap, .fixString: - length = Int(decoder.storage[self.base] - valueFormat.rawValue) - self.base += length + 1 - case .uInt8, .int8, .string8: - length = Int(decoder.storage[self.base+1]) - self.base += length + 2 - case .uInt16, .int16, .string16, .array16, .map16: - length = Int(decoder.storage.bigEndianInteger(at: self.base + 1) as UInt16) - self.base += length + 3 - case .uInt32, .int32, .string32, .array32, .map32, .float32: - length = Int(decoder.storage.bigEndianInteger(at: self.base + 1) as UInt32) - self.base += length + 5 - case .uInt64, .int64, .float64: - length = Int(decoder.storage.bigEndianInteger(at: self.base + 1) as UInt64) - self.base += length + 9 - case .nil, .false, .true: - length = 0 - self.base += 1 - } - decoder.dictionary[key] = (valueFormat, length) } } @@ -234,11 +258,21 @@ struct MsgPckKeyedDecodingContainer: KeyedDecodingContainerProtoco } func decodeNil(forKey key: K) throws -> Bool { - return decoder.dictionary[key.stringValue]?.0 == FormatID.nil + switch decoder.dictionary[key.stringValue]! { + case .constant(let format): + return format == .nil + default: + return false + } } func decode(_ type: Bool.Type, forKey key: K) throws -> Bool { - let format = decoder.dictionary[key.stringValue]!.0 + guard let value = decoder.dictionary[key.stringValue] else { + throw DecodingError.keyNotFound(key, .init(codingPath: codingPath, debugDescription: "Key not found")) + } + guard case let .constant(format) = value else { + throw DecodingError.typeMismatch(type, .init(codingPath: codingPath, debugDescription: "Expected bool but found \(value)")) + } switch format { case .true, .false: return format == .true @@ -247,24 +281,37 @@ struct MsgPckKeyedDecodingContainer: KeyedDecodingContainerProtoco } } + func formattedPointer(for key: K, format expectedFormat: FormatID, type: T.Type) throws -> Int { + guard let value = decoder.dictionary[key.stringValue] else { + throw DecodingError.keyNotFound(key, .init(codingPath: codingPath, debugDescription: "Key not found")) + } + guard case let .fixedWidth(format, pointer) = value else { + throw DecodingError.dataCorruptedError(forKey: key, in: self, debugDescription: "Expected fixed width value but found \(value)") + } + guard format == expectedFormat else { + throw DecodingError.typeMismatch(type, .init(codingPath: codingPath, debugDescription: "Expected \(expectedFormat) but found \(format)")) + } + return pointer + } + func decode(_ type: Int.Type, forKey key: K) throws -> Int { fatalError("not implemented") } func decode(_ type: Int8.Type, forKey key: K) throws -> Int8 { - fatalError("not implemented") + return decoder.storage.read(at: try formattedPointer(for: key, format: .int8, type: type)) } func decode(_ type: Int16.Type, forKey key: K) throws -> Int16 { - fatalError("not implemented") + return decoder.storage.bigEndianInteger(at: try formattedPointer(for: key, format: .int16, type: type)) } func decode(_ type: Int32.Type, forKey key: K) throws -> Int32 { - fatalError("not implemented") + return decoder.storage.bigEndianInteger(at: try formattedPointer(for: key, format: .int32, type: type)) } func decode(_ type: Int64.Type, forKey key: K) throws -> Int64 { - fatalError("not implemented") + return decoder.storage.bigEndianInteger(at: try formattedPointer(for: key, format: .int64, type: type)) } func decode(_ type: UInt.Type, forKey key: K) throws -> UInt { @@ -272,27 +319,29 @@ struct MsgPckKeyedDecodingContainer: KeyedDecodingContainerProtoco } func decode(_ type: UInt8.Type, forKey key: K) throws -> UInt8 { - fatalError("not implemented") + return decoder.storage[try formattedPointer(for: key, format: .uInt8, type: type)] } func decode(_ type: UInt16.Type, forKey key: K) throws -> UInt16 { - fatalError("not implemented") + return decoder.storage.bigEndianInteger(at: try formattedPointer(for: key, format: .uInt16, type: type)) } func decode(_ type: UInt32.Type, forKey key: K) throws -> UInt32 { - fatalError("not implemented") + return decoder.storage.bigEndianInteger(at: try formattedPointer(for: key, format: .uInt32, type: type)) } func decode(_ type: UInt64.Type, forKey key: K) throws -> UInt64 { - fatalError("not implemented") + return decoder.storage.bigEndianInteger(at: try formattedPointer(for: key, format: .uInt64, type: type)) } func decode(_ type: Float.Type, forKey key: K) throws -> Float { - fatalError("not implemented") + let pointer = try formattedPointer(for: key, format: .float32, type: type) + return Float(bitPattern: decoder.storage.bigEndianInteger(at: pointer) as UInt32) } func decode(_ type: Double.Type, forKey key: K) throws -> Double { - fatalError("not implemented") + let pointer = try formattedPointer(for: key, format: .float64, type: type) + return Double(bitPattern: decoder.storage.bigEndianInteger(at: pointer) as UInt64) } func decode(_ type: String.Type, forKey key: K) throws -> String { diff --git a/MsgPackTests/MsgPackTests.swift b/MsgPackTests/MsgPackTests.swift index e8a5ee0..ad9ba4e 100644 --- a/MsgPackTests/MsgPackTests.swift +++ b/MsgPackTests/MsgPackTests.swift @@ -113,6 +113,10 @@ class MsgPackTests: XCTestCase { struct Simple: Codable { let a: Bool let b: Bool + let c: Bool? + let d: Bool? + let e: Bool? + let f: Int8 } func roundtrip(value: T) throws -> T { @@ -121,7 +125,7 @@ class MsgPackTests: XCTestCase { func testExample() { do { - print("roundtrip:", try roundtrip(value: "Simple 🎁")) + print("roundtrip:", try roundtrip(value: Simple(a: true, b: false, c: true, d: false, e: nil, f: -8))) } catch { print(error) } diff --git a/Playground.playground/Contents.swift b/Playground.playground/Contents.swift index 48ae97b..29d7d8f 100644 --- a/Playground.playground/Contents.swift +++ b/Playground.playground/Contents.swift @@ -47,8 +47,9 @@ struct Simple: Codable { let c: Bool? let d: Bool? let e: Bool? + let f: Double } try roundtrip(value: -56.4) try roundtrip(value: "Hello world 😎") -try roundtrip(value: Simple(a: true, b: false, c: true, d: false, e: nil)) +try roundtrip(value: Simple(a: true, b: false, c: true, d: false, e: nil, f: 8.3))