mirror of
https://github.com/Dev1an/MsgPack.git
synced 2026-03-25 08:45:55 +00:00
Decode dictionaries containing bools
This commit is contained in:
parent
aca03f1952
commit
b6fbe7b393
4 changed files with 161 additions and 63 deletions
|
|
@ -24,13 +24,15 @@ class IntermediateDecoder: Swift.Decoder {
|
|||
|
||||
var storage: Data
|
||||
var offset = 0
|
||||
|
||||
var dictionary = [String:(FormatID, Int)]()
|
||||
|
||||
init(with data: Data) {
|
||||
storage = data
|
||||
}
|
||||
|
||||
func container<Key>(keyedBy type: Key.Type) throws -> KeyedDecodingContainer<Key> where Key : CodingKey {
|
||||
fatalError()
|
||||
return KeyedDecodingContainer(try MsgPckKeyedDecodingContainer<Key>(referringTo: self, at: offset, with: []))
|
||||
}
|
||||
|
||||
func unkeyedContainer() throws -> UnkeyedDecodingContainer {
|
||||
|
|
@ -38,23 +40,20 @@ class IntermediateDecoder: Swift.Decoder {
|
|||
}
|
||||
|
||||
func singleValueContainer() throws -> SingleValueDecodingContainer {
|
||||
return MsgPckSingleValueDecodingContainer(refferingTo: self)
|
||||
return MsgPckSingleValueDecodingContainer(decoder: self, base: 0, codingPath: [])
|
||||
}
|
||||
}
|
||||
|
||||
struct MsgPckSingleValueDecodingContainer: SingleValueDecodingContainer {
|
||||
var codingPath = [CodingKey]()
|
||||
var decoder: IntermediateDecoder
|
||||
var base = 0
|
||||
let decoder: IntermediateDecoder
|
||||
var base: Int
|
||||
|
||||
var codingPath: [CodingKey]
|
||||
|
||||
enum Error: Swift.Error {
|
||||
case invalidFormat(UInt8)
|
||||
}
|
||||
|
||||
init(refferingTo decoder: IntermediateDecoder) {
|
||||
self.decoder = decoder
|
||||
}
|
||||
|
||||
func decodeNil() -> Bool {
|
||||
return decoder.storage[base] == FormatID.nil.rawValue
|
||||
}
|
||||
|
|
@ -147,8 +146,8 @@ struct MsgPckSingleValueDecodingContainer: SingleValueDecodingContainer {
|
|||
return .init(bitPattern: bitPattern)
|
||||
}
|
||||
|
||||
func decode(_ type: String.Type) throws -> String {
|
||||
return try Format.string(from: &decoder.storage)
|
||||
mutating func decode(_ type: String.Type) throws -> String {
|
||||
return try Format.string(from: &decoder.storage, base: &base)
|
||||
}
|
||||
|
||||
func decode<T>(_ type: T.Type) throws -> T where T : Decodable {
|
||||
|
|
@ -159,18 +158,92 @@ struct MsgPckSingleValueDecodingContainer: SingleValueDecodingContainer {
|
|||
struct MsgPckKeyedDecodingContainer<K: CodingKey>: KeyedDecodingContainerProtocol {
|
||||
var codingPath: [CodingKey]
|
||||
|
||||
var allKeys: [K]
|
||||
let decoder: IntermediateDecoder
|
||||
var base = 0
|
||||
|
||||
var allKeys = [K]()
|
||||
|
||||
init(referringTo decoder: IntermediateDecoder, at base: Int, with codingPath: [CodingKey]) throws {
|
||||
self.codingPath = codingPath
|
||||
self.decoder = decoder
|
||||
self.base = base
|
||||
|
||||
let elementCount: Int
|
||||
switch decoder.storage[base] {
|
||||
case FormatID.fixMapRange:
|
||||
elementCount = Int(decoder.storage[base] - FormatID.fixMap.rawValue)
|
||||
self.base += 1
|
||||
case FormatID.map16.rawValue:
|
||||
elementCount = Int(decoder.storage.bigEndianInteger(at: base+1) as UInt16)
|
||||
self.base += 3
|
||||
case FormatID.map32.rawValue:
|
||||
elementCount = Int(decoder.storage.bigEndianInteger(at: base+1) as UInt32)
|
||||
self.base += 5
|
||||
default:
|
||||
throw DecodingError.typeMismatch(Dictionary<String,Any>.self, .init(codingPath: codingPath, debugDescription: "Expected a MsgPack map format, but found 0x\(String(decoder.storage[base], radix: 16))"))
|
||||
}
|
||||
for cursor in 0 ..< elementCount {
|
||||
let key = try Format.string(from: &decoder.storage, base: &self.base)
|
||||
|
||||
let valueFormat: FormatID
|
||||
if let valueFormatLookup = FormatID(rawValue: decoder.storage[self.base]) {
|
||||
valueFormat = valueFormatLookup
|
||||
} else {
|
||||
switch decoder.storage[self.base] {
|
||||
case FormatID.fixMapRange:
|
||||
valueFormat = .fixMap
|
||||
case FormatID.fixStringRange:
|
||||
valueFormat = .fixString
|
||||
case FormatID.positiveInt7Range:
|
||||
valueFormat = .positiveInt7
|
||||
case FormatID.negativeInt5Range:
|
||||
valueFormat = .negativeInt5
|
||||
default:
|
||||
throw DecodingError.dataCorrupted(.init(codingPath: codingPath, debugDescription: "Unknown format: 0x\(String(decoder.storage[self.base], radix: 16))"))
|
||||
}
|
||||
}
|
||||
|
||||
let length: Int
|
||||
switch valueFormat {
|
||||
case .positiveInt7, .negativeInt5, .fixArray, .fixMap, .fixString:
|
||||
length = Int(decoder.storage[self.base] - valueFormat.rawValue)
|
||||
self.base += length + 1
|
||||
case .uInt8, .int8, .string8:
|
||||
length = Int(decoder.storage[self.base+1])
|
||||
self.base += length + 2
|
||||
case .uInt16, .int16, .string16, .array16, .map16:
|
||||
length = Int(decoder.storage.bigEndianInteger(at: self.base + 1) as UInt16)
|
||||
self.base += length + 3
|
||||
case .uInt32, .int32, .string32, .array32, .map32, .float32:
|
||||
length = Int(decoder.storage.bigEndianInteger(at: self.base + 1) as UInt32)
|
||||
self.base += length + 5
|
||||
case .uInt64, .int64, .float64:
|
||||
length = Int(decoder.storage.bigEndianInteger(at: self.base + 1) as UInt64)
|
||||
self.base += length + 9
|
||||
case .nil, .false, .true:
|
||||
length = 0
|
||||
self.base += 1
|
||||
}
|
||||
decoder.dictionary[key] = (valueFormat, length)
|
||||
}
|
||||
}
|
||||
|
||||
func contains(_ key: K) -> Bool {
|
||||
fatalError("not implemented")
|
||||
return decoder.dictionary[key.stringValue] != nil
|
||||
}
|
||||
|
||||
func decodeNil(forKey key: K) throws -> Bool {
|
||||
fatalError("not implemented")
|
||||
return decoder.dictionary[key.stringValue]?.0 == FormatID.nil
|
||||
}
|
||||
|
||||
func decode(_ type: Bool.Type, forKey key: K) throws -> Bool {
|
||||
fatalError("not implemented")
|
||||
let format = decoder.dictionary[key.stringValue]!.0
|
||||
switch format {
|
||||
case .true, .false:
|
||||
return format == .true
|
||||
default:
|
||||
throw DecodingError.typeMismatch(type, .init(codingPath: codingPath, debugDescription: "Expected bool but found \(format)"))
|
||||
}
|
||||
}
|
||||
|
||||
func decode(_ type: Int.Type, forKey key: K) throws -> Int {
|
||||
|
|
@ -246,5 +319,4 @@ struct MsgPckKeyedDecodingContainer<K: CodingKey>: KeyedDecodingContainerProtoco
|
|||
}
|
||||
|
||||
typealias Key = K
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -52,22 +52,19 @@ public enum FormatID: UInt8 {
|
|||
case `false` = 0xC2
|
||||
case `true`
|
||||
|
||||
case bin8
|
||||
case bin16
|
||||
case bin32
|
||||
// case bin8
|
||||
// case bin16
|
||||
// case bin32
|
||||
//
|
||||
// case ext8
|
||||
// case ext16
|
||||
// case ext32
|
||||
|
||||
case ext8
|
||||
case ext16
|
||||
case ext32
|
||||
|
||||
case float32
|
||||
case float32 = 0xCA
|
||||
case float64
|
||||
|
||||
case firstPositiveInt7 = 0b00000000
|
||||
case lastPositiveInt7 = 0b01111111
|
||||
|
||||
case firstNegativeInt5 = 0b11100000
|
||||
case lastNegativeInt5 = 0b11111111
|
||||
case positiveInt7 = 0b00000000
|
||||
case negativeInt5 = 0b11100000
|
||||
|
||||
case uInt8 = 0xCC
|
||||
case uInt16
|
||||
|
|
@ -79,27 +76,31 @@ public enum FormatID: UInt8 {
|
|||
case int32
|
||||
case int64
|
||||
|
||||
case fixExt1
|
||||
case fixExt2
|
||||
case fixExt4
|
||||
case fixExt8
|
||||
case fixExt16
|
||||
// case fixExt1
|
||||
// case fixExt2
|
||||
// case fixExt4
|
||||
// case fixExt8
|
||||
// case fixExt16
|
||||
|
||||
case fixStringStart = 0b10100000
|
||||
case fixStringEnd = 0b10111111
|
||||
case fixString = 0b10100000
|
||||
case string8 = 0xD9
|
||||
case string16
|
||||
case string32
|
||||
|
||||
case fixArrayStart = 0b10010000
|
||||
case fixArrayEnd = 0b10011111
|
||||
case fixArray = 0b10010000
|
||||
case array16 = 0xDC
|
||||
case array32
|
||||
|
||||
case fixMapStart = 0b10000000
|
||||
case fixMapEnd = 0b10001111
|
||||
case fixMap = 0b10000000
|
||||
case map16 = 0xDE
|
||||
case map32
|
||||
|
||||
static let positiveInt7Range = FormatID.positiveInt7.rawValue ..< FormatID.fixMap.rawValue
|
||||
static let negativeInt5Range = FormatID.negativeInt5.rawValue ..< 0b11111111
|
||||
|
||||
static let fixMapRange = FormatID.fixMap.rawValue ..< FormatID.fixArray.rawValue
|
||||
static let fixArrayRange = FormatID.fixArray.rawValue ..< FormatID.fixString.rawValue
|
||||
static let fixStringRange = FormatID.fixString.rawValue ..< FormatID.nil.rawValue
|
||||
}
|
||||
|
||||
extension Format {
|
||||
|
|
@ -116,9 +117,9 @@ extension Format {
|
|||
|
||||
// MARK: Small integers (< 8 bit)
|
||||
case .positiveInt7(let value):
|
||||
data.append(value | FormatID.firstPositiveInt7.rawValue)
|
||||
data.append(value | FormatID.positiveInt7.rawValue)
|
||||
case .negativeInt5(let value):
|
||||
data.append(value | FormatID.firstNegativeInt5.rawValue)
|
||||
data.append(value | FormatID.negativeInt5.rawValue)
|
||||
|
||||
// MARK: Unsigned integers
|
||||
case .uInt8(let value):
|
||||
|
|
@ -179,7 +180,7 @@ extension Format {
|
|||
// MARK: Strings
|
||||
case .fixString(let utf8Data):
|
||||
precondition(utf8Data.count < 32, "fix strings cannot contain more than 31 bytes")
|
||||
data.append( UInt8(utf8Data.count) | FormatID.fixStringStart.rawValue)
|
||||
data.append( UInt8(utf8Data.count) | FormatID.fixString.rawValue)
|
||||
data.append(utf8Data)
|
||||
case .string8(let utf8Data):
|
||||
data.append(contentsOf: [FormatID.string8.rawValue, UInt8(utf8Data.count)])
|
||||
|
|
@ -200,7 +201,7 @@ extension Format {
|
|||
// MARK: Arrays
|
||||
case .fixArray(let array):
|
||||
precondition(array.count < 16, "fix arrays cannot contain more than 15 elements")
|
||||
data.append( UInt8(array.count) | FormatID.fixArrayStart.rawValue)
|
||||
data.append( UInt8(array.count) | FormatID.fixArray.rawValue)
|
||||
for element in array {
|
||||
element.appendTo(data: &data)
|
||||
}
|
||||
|
|
@ -224,7 +225,7 @@ extension Format {
|
|||
// MARK: Maps
|
||||
case .fixMap(let pairs):
|
||||
precondition(pairs.count < 16, "fix maps cannot contain more than 15 key-value pairs")
|
||||
data.append( UInt8(pairs.count) | FormatID.fixMapStart.rawValue)
|
||||
data.append( UInt8(pairs.count) | FormatID.fixMap.rawValue)
|
||||
for (key, value) in pairs {
|
||||
key.appendTo(data: &data)
|
||||
value.appendTo(data: &data)
|
||||
|
|
@ -329,42 +330,50 @@ extension Format {
|
|||
}
|
||||
|
||||
extension Format {
|
||||
static func string(from data: inout Data, offset: Int = 0) throws -> String {
|
||||
switch data[offset] {
|
||||
case FormatID.fixStringStart.rawValue ... FormatID.fixStringEnd.rawValue:
|
||||
let length = Int(data[offset] & 0b00011111)
|
||||
static func string(from data: inout Data, base: inout Int) throws -> String {
|
||||
switch data[base] {
|
||||
case FormatID.fixStringRange:
|
||||
let length = Int(data[base] & 0b00011111)
|
||||
base += 1
|
||||
guard let string = data.withUnsafeMutableBytes({
|
||||
String(bytesNoCopy: $0.advanced(by: offset + 1), length: length, encoding: .utf8, freeWhenDone: false)
|
||||
String(bytesNoCopy: $0.advanced(by: base), length: length, encoding: .utf8, freeWhenDone: false)
|
||||
}) else {
|
||||
throw DecodingError.dataCorrupted(.init(codingPath: [], debugDescription: "not a valid string"))
|
||||
}
|
||||
base += length
|
||||
return string
|
||||
case FormatID.string8.rawValue:
|
||||
let length = Int(data[offset + 1])
|
||||
let length = Int(data[base + 1])
|
||||
base += 2
|
||||
guard let string = data.withUnsafeMutableBytes({
|
||||
String(bytesNoCopy: $0.advanced(by: offset + 2), length: length, encoding: .utf8, freeWhenDone: false)
|
||||
String(bytesNoCopy: $0.advanced(by: base), length: length, encoding: .utf8, freeWhenDone: false)
|
||||
}) else {
|
||||
throw DecodingError.dataCorrupted(.init(codingPath: [], debugDescription: "not a valid string"))
|
||||
}
|
||||
base += length
|
||||
return string
|
||||
case FormatID.string16.rawValue:
|
||||
let length = Int(data.bigEndianInteger(at: offset) as UInt16)
|
||||
let length = Int(data.bigEndianInteger(at: base + 1) as UInt16)
|
||||
base += 3
|
||||
guard let string = data.withUnsafeMutableBytes({
|
||||
String(bytesNoCopy: $0.advanced(by: offset + 3), length: length, encoding: .utf8, freeWhenDone: false)
|
||||
String(bytesNoCopy: $0.advanced(by: base), length: length, encoding: .utf8, freeWhenDone: false)
|
||||
}) else {
|
||||
throw DecodingError.dataCorrupted(.init(codingPath: [], debugDescription: "not a valid string"))
|
||||
}
|
||||
base += length
|
||||
return string
|
||||
case FormatID.string32.rawValue:
|
||||
let length = Int(data.bigEndianInteger(at: offset) as UInt16)
|
||||
let length = Int(data.bigEndianInteger(at: base + 1) as UInt16)
|
||||
base += 5
|
||||
guard let string = data.withUnsafeMutableBytes({
|
||||
String(bytesNoCopy: $0.advanced(by: offset + 5), length: length, encoding: .utf8, freeWhenDone: false)
|
||||
String(bytesNoCopy: $0.advanced(by: base), length: length, encoding: .utf8, freeWhenDone: false)
|
||||
}) else {
|
||||
throw DecodingError.dataCorrupted(.init(codingPath: [], debugDescription: "not a valid string"))
|
||||
}
|
||||
base += length
|
||||
return string
|
||||
default:
|
||||
throw DecodingError.typeMismatch(String.self, .init(codingPath: [], debugDescription: "Wrong string format: \(data[offset])"))
|
||||
throw DecodingError.typeMismatch(String.self, .init(codingPath: [], debugDescription: "Wrong string format: \(data[base])"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,11 +16,13 @@ class MsgPackTests: XCTestCase {
|
|||
// }
|
||||
|
||||
var encoder: MsgPack.Encoder!
|
||||
var decoder: MsgPack.Decoder!
|
||||
|
||||
override func setUp() {
|
||||
super.setUp()
|
||||
|
||||
encoder = Encoder()
|
||||
decoder = Decoder()
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -108,9 +110,18 @@ class MsgPackTests: XCTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
struct Simple: Codable {
|
||||
let a: Bool
|
||||
let b: Bool
|
||||
}
|
||||
|
||||
func roundtrip<T: Codable>(value: T) throws -> T {
|
||||
return try decoder.decode(T.self, from: encoder.encode(value))
|
||||
}
|
||||
|
||||
func testExample() {
|
||||
do {
|
||||
print(try encoder.encode([[Int8(-3)], [5], [8]]))
|
||||
print("roundtrip:", try roundtrip(value: Simple(a: false, b: false)))
|
||||
} catch {
|
||||
print(error)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -33,10 +33,7 @@ let graph = Graph(
|
|||
]
|
||||
)
|
||||
|
||||
let encodedGraph = try encoder.encode(graph)
|
||||
for byte in encodedGraph {
|
||||
print(String(byte, radix:16))
|
||||
}
|
||||
try encoder.encode(graph)
|
||||
|
||||
let decoder = Decoder()
|
||||
|
||||
|
|
@ -44,4 +41,13 @@ func roundtrip<T: Codable>(value: T) throws -> T {
|
|||
return try decoder.decode(T.self, from: encoder.encode(value))
|
||||
}
|
||||
|
||||
try roundtrip(value: "Hello")
|
||||
struct Simple: Codable {
|
||||
let a: Bool
|
||||
let b: Bool
|
||||
let c: Bool?
|
||||
let d: Bool?
|
||||
let e: Bool?
|
||||
}
|
||||
|
||||
try roundtrip(value: -56.4)
|
||||
try roundtrip(value: Simple(a: true, b: false, c: true, d: false, e: nil))
|
||||
|
|
|
|||
Loading…
Reference in a new issue