Decode dictionaries containing bools

This commit is contained in:
Damiaan Dufaux 2017-08-03 11:54:13 +02:00
parent aca03f1952
commit b6fbe7b393
4 changed files with 161 additions and 63 deletions

View file

@ -24,13 +24,15 @@ class IntermediateDecoder: Swift.Decoder {
var storage: Data
var offset = 0
var dictionary = [String:(FormatID, Int)]()
init(with data: Data) {
storage = data
}
func container<Key>(keyedBy type: Key.Type) throws -> KeyedDecodingContainer<Key> where Key : CodingKey {
fatalError()
return KeyedDecodingContainer(try MsgPckKeyedDecodingContainer<Key>(referringTo: self, at: offset, with: []))
}
func unkeyedContainer() throws -> UnkeyedDecodingContainer {
@ -38,23 +40,20 @@ class IntermediateDecoder: Swift.Decoder {
}
func singleValueContainer() throws -> SingleValueDecodingContainer {
return MsgPckSingleValueDecodingContainer(refferingTo: self)
return MsgPckSingleValueDecodingContainer(decoder: self, base: 0, codingPath: [])
}
}
struct MsgPckSingleValueDecodingContainer: SingleValueDecodingContainer {
var codingPath = [CodingKey]()
var decoder: IntermediateDecoder
var base = 0
let decoder: IntermediateDecoder
var base: Int
var codingPath: [CodingKey]
enum Error: Swift.Error {
case invalidFormat(UInt8)
}
init(refferingTo decoder: IntermediateDecoder) {
self.decoder = decoder
}
func decodeNil() -> Bool {
return decoder.storage[base] == FormatID.nil.rawValue
}
@ -147,8 +146,8 @@ struct MsgPckSingleValueDecodingContainer: SingleValueDecodingContainer {
return .init(bitPattern: bitPattern)
}
func decode(_ type: String.Type) throws -> String {
return try Format.string(from: &decoder.storage)
mutating func decode(_ type: String.Type) throws -> String {
return try Format.string(from: &decoder.storage, base: &base)
}
func decode<T>(_ type: T.Type) throws -> T where T : Decodable {
@ -159,18 +158,92 @@ struct MsgPckSingleValueDecodingContainer: SingleValueDecodingContainer {
struct MsgPckKeyedDecodingContainer<K: CodingKey>: KeyedDecodingContainerProtocol {
var codingPath: [CodingKey]
var allKeys: [K]
let decoder: IntermediateDecoder
var base = 0
var allKeys = [K]()
init(referringTo decoder: IntermediateDecoder, at base: Int, with codingPath: [CodingKey]) throws {
self.codingPath = codingPath
self.decoder = decoder
self.base = base
let elementCount: Int
switch decoder.storage[base] {
case FormatID.fixMapRange:
elementCount = Int(decoder.storage[base] - FormatID.fixMap.rawValue)
self.base += 1
case FormatID.map16.rawValue:
elementCount = Int(decoder.storage.bigEndianInteger(at: base+1) as UInt16)
self.base += 3
case FormatID.map32.rawValue:
elementCount = Int(decoder.storage.bigEndianInteger(at: base+1) as UInt32)
self.base += 5
default:
throw DecodingError.typeMismatch(Dictionary<String,Any>.self, .init(codingPath: codingPath, debugDescription: "Expected a MsgPack map format, but found 0x\(String(decoder.storage[base], radix: 16))"))
}
for cursor in 0 ..< elementCount {
let key = try Format.string(from: &decoder.storage, base: &self.base)
let valueFormat: FormatID
if let valueFormatLookup = FormatID(rawValue: decoder.storage[self.base]) {
valueFormat = valueFormatLookup
} else {
switch decoder.storage[self.base] {
case FormatID.fixMapRange:
valueFormat = .fixMap
case FormatID.fixStringRange:
valueFormat = .fixString
case FormatID.positiveInt7Range:
valueFormat = .positiveInt7
case FormatID.negativeInt5Range:
valueFormat = .negativeInt5
default:
throw DecodingError.dataCorrupted(.init(codingPath: codingPath, debugDescription: "Unknown format: 0x\(String(decoder.storage[self.base], radix: 16))"))
}
}
let length: Int
switch valueFormat {
case .positiveInt7, .negativeInt5, .fixArray, .fixMap, .fixString:
length = Int(decoder.storage[self.base] - valueFormat.rawValue)
self.base += length + 1
case .uInt8, .int8, .string8:
length = Int(decoder.storage[self.base+1])
self.base += length + 2
case .uInt16, .int16, .string16, .array16, .map16:
length = Int(decoder.storage.bigEndianInteger(at: self.base + 1) as UInt16)
self.base += length + 3
case .uInt32, .int32, .string32, .array32, .map32, .float32:
length = Int(decoder.storage.bigEndianInteger(at: self.base + 1) as UInt32)
self.base += length + 5
case .uInt64, .int64, .float64:
length = Int(decoder.storage.bigEndianInteger(at: self.base + 1) as UInt64)
self.base += length + 9
case .nil, .false, .true:
length = 0
self.base += 1
}
decoder.dictionary[key] = (valueFormat, length)
}
}
func contains(_ key: K) -> Bool {
fatalError("not implemented")
return decoder.dictionary[key.stringValue] != nil
}
func decodeNil(forKey key: K) throws -> Bool {
fatalError("not implemented")
return decoder.dictionary[key.stringValue]?.0 == FormatID.nil
}
func decode(_ type: Bool.Type, forKey key: K) throws -> Bool {
fatalError("not implemented")
let format = decoder.dictionary[key.stringValue]!.0
switch format {
case .true, .false:
return format == .true
default:
throw DecodingError.typeMismatch(type, .init(codingPath: codingPath, debugDescription: "Expected bool but found \(format)"))
}
}
func decode(_ type: Int.Type, forKey key: K) throws -> Int {
@ -246,5 +319,4 @@ struct MsgPckKeyedDecodingContainer<K: CodingKey>: KeyedDecodingContainerProtoco
}
typealias Key = K
}

View file

@ -52,22 +52,19 @@ public enum FormatID: UInt8 {
case `false` = 0xC2
case `true`
case bin8
case bin16
case bin32
// case bin8
// case bin16
// case bin32
//
// case ext8
// case ext16
// case ext32
case ext8
case ext16
case ext32
case float32
case float32 = 0xCA
case float64
case firstPositiveInt7 = 0b00000000
case lastPositiveInt7 = 0b01111111
case firstNegativeInt5 = 0b11100000
case lastNegativeInt5 = 0b11111111
case positiveInt7 = 0b00000000
case negativeInt5 = 0b11100000
case uInt8 = 0xCC
case uInt16
@ -79,27 +76,31 @@ public enum FormatID: UInt8 {
case int32
case int64
case fixExt1
case fixExt2
case fixExt4
case fixExt8
case fixExt16
// case fixExt1
// case fixExt2
// case fixExt4
// case fixExt8
// case fixExt16
case fixStringStart = 0b10100000
case fixStringEnd = 0b10111111
case fixString = 0b10100000
case string8 = 0xD9
case string16
case string32
case fixArrayStart = 0b10010000
case fixArrayEnd = 0b10011111
case fixArray = 0b10010000
case array16 = 0xDC
case array32
case fixMapStart = 0b10000000
case fixMapEnd = 0b10001111
case fixMap = 0b10000000
case map16 = 0xDE
case map32
static let positiveInt7Range = FormatID.positiveInt7.rawValue ..< FormatID.fixMap.rawValue
static let negativeInt5Range = FormatID.negativeInt5.rawValue ..< 0b11111111
static let fixMapRange = FormatID.fixMap.rawValue ..< FormatID.fixArray.rawValue
static let fixArrayRange = FormatID.fixArray.rawValue ..< FormatID.fixString.rawValue
static let fixStringRange = FormatID.fixString.rawValue ..< FormatID.nil.rawValue
}
extension Format {
@ -116,9 +117,9 @@ extension Format {
// MARK: Small integers (< 8 bit)
case .positiveInt7(let value):
data.append(value | FormatID.firstPositiveInt7.rawValue)
data.append(value | FormatID.positiveInt7.rawValue)
case .negativeInt5(let value):
data.append(value | FormatID.firstNegativeInt5.rawValue)
data.append(value | FormatID.negativeInt5.rawValue)
// MARK: Unsigned integers
case .uInt8(let value):
@ -179,7 +180,7 @@ extension Format {
// MARK: Strings
case .fixString(let utf8Data):
precondition(utf8Data.count < 32, "fix strings cannot contain more than 31 bytes")
data.append( UInt8(utf8Data.count) | FormatID.fixStringStart.rawValue)
data.append( UInt8(utf8Data.count) | FormatID.fixString.rawValue)
data.append(utf8Data)
case .string8(let utf8Data):
data.append(contentsOf: [FormatID.string8.rawValue, UInt8(utf8Data.count)])
@ -200,7 +201,7 @@ extension Format {
// MARK: Arrays
case .fixArray(let array):
precondition(array.count < 16, "fix arrays cannot contain more than 15 elements")
data.append( UInt8(array.count) | FormatID.fixArrayStart.rawValue)
data.append( UInt8(array.count) | FormatID.fixArray.rawValue)
for element in array {
element.appendTo(data: &data)
}
@ -224,7 +225,7 @@ extension Format {
// MARK: Maps
case .fixMap(let pairs):
precondition(pairs.count < 16, "fix maps cannot contain more than 15 key-value pairs")
data.append( UInt8(pairs.count) | FormatID.fixMapStart.rawValue)
data.append( UInt8(pairs.count) | FormatID.fixMap.rawValue)
for (key, value) in pairs {
key.appendTo(data: &data)
value.appendTo(data: &data)
@ -329,42 +330,50 @@ extension Format {
}
extension Format {
static func string(from data: inout Data, offset: Int = 0) throws -> String {
switch data[offset] {
case FormatID.fixStringStart.rawValue ... FormatID.fixStringEnd.rawValue:
let length = Int(data[offset] & 0b00011111)
static func string(from data: inout Data, base: inout Int) throws -> String {
switch data[base] {
case FormatID.fixStringRange:
let length = Int(data[base] & 0b00011111)
base += 1
guard let string = data.withUnsafeMutableBytes({
String(bytesNoCopy: $0.advanced(by: offset + 1), length: length, encoding: .utf8, freeWhenDone: false)
String(bytesNoCopy: $0.advanced(by: base), length: length, encoding: .utf8, freeWhenDone: false)
}) else {
throw DecodingError.dataCorrupted(.init(codingPath: [], debugDescription: "not a valid string"))
}
base += length
return string
case FormatID.string8.rawValue:
let length = Int(data[offset + 1])
let length = Int(data[base + 1])
base += 2
guard let string = data.withUnsafeMutableBytes({
String(bytesNoCopy: $0.advanced(by: offset + 2), length: length, encoding: .utf8, freeWhenDone: false)
String(bytesNoCopy: $0.advanced(by: base), length: length, encoding: .utf8, freeWhenDone: false)
}) else {
throw DecodingError.dataCorrupted(.init(codingPath: [], debugDescription: "not a valid string"))
}
base += length
return string
case FormatID.string16.rawValue:
let length = Int(data.bigEndianInteger(at: offset) as UInt16)
let length = Int(data.bigEndianInteger(at: base + 1) as UInt16)
base += 3
guard let string = data.withUnsafeMutableBytes({
String(bytesNoCopy: $0.advanced(by: offset + 3), length: length, encoding: .utf8, freeWhenDone: false)
String(bytesNoCopy: $0.advanced(by: base), length: length, encoding: .utf8, freeWhenDone: false)
}) else {
throw DecodingError.dataCorrupted(.init(codingPath: [], debugDescription: "not a valid string"))
}
base += length
return string
case FormatID.string32.rawValue:
let length = Int(data.bigEndianInteger(at: offset) as UInt16)
let length = Int(data.bigEndianInteger(at: base + 1) as UInt16)
base += 5
guard let string = data.withUnsafeMutableBytes({
String(bytesNoCopy: $0.advanced(by: offset + 5), length: length, encoding: .utf8, freeWhenDone: false)
String(bytesNoCopy: $0.advanced(by: base), length: length, encoding: .utf8, freeWhenDone: false)
}) else {
throw DecodingError.dataCorrupted(.init(codingPath: [], debugDescription: "not a valid string"))
}
base += length
return string
default:
throw DecodingError.typeMismatch(String.self, .init(codingPath: [], debugDescription: "Wrong string format: \(data[offset])"))
throw DecodingError.typeMismatch(String.self, .init(codingPath: [], debugDescription: "Wrong string format: \(data[base])"))
}
}
}

View file

@ -16,11 +16,13 @@ class MsgPackTests: XCTestCase {
// }
var encoder: MsgPack.Encoder!
var decoder: MsgPack.Decoder!
override func setUp() {
super.setUp()
encoder = Encoder()
decoder = Decoder()
}
@ -108,9 +110,18 @@ class MsgPackTests: XCTestCase {
}
}
struct Simple: Codable {
let a: Bool
let b: Bool
}
func roundtrip<T: Codable>(value: T) throws -> T {
return try decoder.decode(T.self, from: encoder.encode(value))
}
func testExample() {
do {
print(try encoder.encode([[Int8(-3)], [5], [8]]))
print("roundtrip:", try roundtrip(value: Simple(a: false, b: false)))
} catch {
print(error)
}

View file

@ -33,10 +33,7 @@ let graph = Graph(
]
)
let encodedGraph = try encoder.encode(graph)
for byte in encodedGraph {
print(String(byte, radix:16))
}
try encoder.encode(graph)
let decoder = Decoder()
@ -44,4 +41,13 @@ func roundtrip<T: Codable>(value: T) throws -> T {
return try decoder.decode(T.self, from: encoder.encode(value))
}
try roundtrip(value: "Hello")
struct Simple: Codable {
let a: Bool
let b: Bool
let c: Bool?
let d: Bool?
let e: Bool?
}
try roundtrip(value: -56.4)
try roundtrip(value: Simple(a: true, b: false, c: true, d: false, e: nil))