12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058 |
- # Protocol Buffers - Google's data interchange format
- # Copyright 2008 Google Inc. All rights reserved.
- # https://developers.google.com/protocol-buffers/
- #
- # Redistribution and use in source and binary forms, with or without
- # modification, are permitted provided that the following conditions are
- # met:
- #
- # * Redistributions of source code must retain the above copyright
- # notice, this list of conditions and the following disclaimer.
- # * Redistributions in binary form must reproduce the above
- # copyright notice, this list of conditions and the following disclaimer
- # in the documentation and/or other materials provided with the
- # distribution.
- # * Neither the name of Google Inc. nor the names of its
- # contributors may be used to endorse or promote products derived from
- # this software without specific prior written permission.
- #
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
- # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
- # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
- # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
- # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
- # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
- # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
- # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
- # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
- # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
- """Code for decoding protocol buffer primitives.
- This code is very similar to encoder.py -- read the docs for that module first.
- A "decoder" is a function with the signature:
- Decode(buffer, pos, end, message, field_dict)
- The arguments are:
- buffer: The string containing the encoded message.
- pos: The current position in the string.
- end: The position in the string where the current message ends. May be
- less than len(buffer) if we're reading a sub-message.
- message: The message object into which we're parsing.
- field_dict: message._fields (avoids a hashtable lookup).
- The decoder reads the field and stores it into field_dict, returning the new
- buffer position. A decoder for a repeated field may proactively decode all of
- the elements of that field, if they appear consecutively.
- Note that decoders may throw any of the following:
- IndexError: Indicates a truncated message.
- struct.error: Unpacking of a fixed-width field failed.
- message.DecodeError: Other errors.
- Decoders are expected to raise an exception if they are called with pos > end.
- This allows callers to be lax about bounds checking: it's fine to read past
- "end" as long as you are sure that someone else will notice and throw an
- exception later on.
- Something up the call stack is expected to catch IndexError and struct.error
- and convert them to message.DecodeError.
- Decoders are constructed using decoder constructors with the signature:
- MakeDecoder(field_number, is_repeated, is_packed, key, new_default)
- The arguments are:
- field_number: The field number of the field we want to decode.
- is_repeated: Is the field a repeated field? (bool)
- is_packed: Is the field a packed field? (bool)
- key: The key to use when looking up the field within field_dict.
- (This is actually the FieldDescriptor but nothing in this
- file should depend on that.)
- new_default: A function which takes a message object as a parameter and
- returns a new instance of the default value for this field.
- (This is called for repeated fields and sub-messages, when an
- instance does not already exist.)
- As with encoders, we define a decoder constructor for every type of field.
- Then, for every field of every message class we construct an actual decoder.
- That decoder goes into a dict indexed by tag, so when we decode a message
- we repeatedly read a tag, look up the corresponding decoder, and invoke it.
- """
- __author__ = 'kenton@google.com (Kenton Varda)'
- import struct
- import sys
- import six
- _UCS2_MAXUNICODE = 65535
- if six.PY3:
- long = int
- else:
- import re # pylint: disable=g-import-not-at-top
- _SURROGATE_PATTERN = re.compile(six.u(r'[\ud800-\udfff]'))
- from google.protobuf.internal import containers
- from google.protobuf.internal import encoder
- from google.protobuf.internal import wire_format
- from google.protobuf import message
# Alias for the decode exception.  It exists to avoid conflicts with local
# variables named "message", not as an optimization.
_DecodeError = message.DecodeError

# The literal below overflows and thus becomes IEEE-754 "infinity".
# "float('inf')" is not used because it doesn't work on Windows
# pre-Python-2.6.
_POS_INF = 1e10000
_NEG_INF = -_POS_INF
_NAN = _POS_INF * 0
def _VarintDecoder(mask, result_type):
  """Return a decoder for a basic varint value (does not include tag).

  The decoded value is bitwise-anded with `mask` (e.g. to limit it to 32 bits)
  and then converted with `result_type` before being returned.

  The returned decoder does not take the usual "end" parameter -- the caller is
  expected to do bounds checking after the fact (often the caller can defer
  such checking until later).  It returns a (value, new_pos) pair.

  Args:
    mask: Integer mask applied to the accumulated value.
    result_type: Callable applied to the masked value to produce the result.

  Returns:
    A function DecodeVarint(buffer, pos) -> (value, new_pos).
  """

  def DecodeVarint(buffer, pos):
    accumulated = 0
    bit_offset = 0
    while True:
      byte = six.indexbytes(buffer, pos)
      pos += 1
      # The low seven bits carry payload; the high bit flags continuation.
      accumulated |= (byte & 0x7f) << bit_offset
      if byte & 0x80:
        bit_offset += 7
        if bit_offset >= 64:
          raise _DecodeError('Too many bytes when decoding varint.')
        continue
      return (result_type(accumulated & mask), pos)

  return DecodeVarint
def _SignedVarintDecoder(bits, result_type):
  """Like _VarintDecoder() but decodes signed values.

  The accumulated value is masked to `bits` bits and then sign-extended from
  bit `bits - 1` before being converted with `result_type`.

  Args:
    bits: Width, in bits, of the signed value.
    result_type: Callable applied to the sign-extended value.

  Returns:
    A function DecodeVarint(buffer, pos) -> (value, new_pos).
  """
  value_mask = (1 << bits) - 1
  sign_bit = 1 << (bits - 1)

  def DecodeVarint(buffer, pos):
    accumulated = 0
    bit_offset = 0
    while True:
      byte = six.indexbytes(buffer, pos)
      pos += 1
      accumulated |= (byte & 0x7f) << bit_offset
      if byte & 0x80:
        bit_offset += 7
        if bit_offset >= 64:
          raise _DecodeError('Too many bytes when decoding varint.')
        continue
      truncated = accumulated & value_mask
      # Flipping and then subtracting the sign bit sign-extends the value.
      signed = (truncated ^ sign_bit) - sign_bit
      return (result_type(signed), pos)

  return DecodeVarint
# 32-bit values are forced to int and 64-bit values to long.  This makes
# alternate implementations where the distinction is more significant
# (e.g. the C++ implementation) simpler.

# 64-bit decoders.
_DecodeVarint = _VarintDecoder(0xFFFFFFFFFFFFFFFF, long)
_DecodeSignedVarint = _SignedVarintDecoder(64, long)

# Decoders for values which must be limited to 32 bits.
_DecodeVarint32 = _VarintDecoder(0xFFFFFFFF, int)
_DecodeSignedVarint32 = _SignedVarintDecoder(32, int)
def ReadTag(buffer, pos):
  """Read a tag from the memoryview, and return a (tag_bytes, new_pos) tuple.

  The tag is returned as raw bytes rather than being decoded, so that the
  bytes can be used directly to look up the proper decoder.  This trades work
  that would be done in pure Python (decoding a varint) for work that is done
  in C (searching for a byte string in a hash table).  In a low-level language
  it would be much cheaper to decode the varint and use that, but not in
  Python.

  Args:
    buffer: memoryview object of the encoded bytes
    pos: int of the current position to start from

  Returns:
    Tuple[bytes, int] of the tag data and new position.
  """
  cursor = pos
  # Advance over every byte that has its continuation bit set...
  while six.indexbytes(buffer, cursor) & 0x80:
    cursor += 1
  # ...and then over the final byte of the tag.
  cursor += 1
  return buffer[pos:cursor].tobytes(), cursor
- # --------------------------------------------------------------------
def _SimpleDecoder(wire_type, decode_value):
  """Return a constructor for a decoder for fields of a particular type.

  Args:
    wire_type: The field's wire type.
    decode_value: A function which decodes an individual value, e.g.
      _DecodeVarint().  It is called as decode_value(buffer, pos) and must
      return a (value, new_pos) pair.

  Returns:
    A decoder constructor.  Depending on is_packed / is_repeated it builds a
    packed, repeated or singular field decoder.
  """

  def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default,
                      clear_if_default=False):
    if is_packed:
      read_length = _DecodeVarint

      def DecodePackedField(buffer, pos, end, message, field_dict):
        container = field_dict.get(key)
        if container is None:
          container = field_dict.setdefault(key, new_default(message))
        # The packed payload is prefixed with its length in bytes.
        (limit, pos) = read_length(buffer, pos)
        limit += pos
        if limit > end:
          raise _DecodeError('Truncated message.')
        while pos < limit:
          (item, pos) = decode_value(buffer, pos)
          container.append(item)
        if pos > limit:
          # The last element ran past the payload: discard the corrupt value.
          del container[-1]
          raise _DecodeError('Packed element was truncated.')
        return pos

      return DecodePackedField

    if is_repeated:
      tag_bytes = encoder.TagBytes(field_number, wire_type)
      tag_len = len(tag_bytes)

      def DecodeRepeatedField(buffer, pos, end, message, field_dict):
        container = field_dict.get(key)
        if container is None:
          container = field_dict.setdefault(key, new_default(message))
        while True:
          (item, value_end) = decode_value(buffer, pos)
          container.append(item)
          # Predict that the next tag is another copy of the same repeated
          # field; if so, keep decoding without returning to the caller.
          pos = value_end + tag_len
          if buffer[value_end:pos] == tag_bytes and value_end < end:
            continue
          # Prediction failed.  Return.
          if value_end > end:
            raise _DecodeError('Truncated message.')
          return value_end

      return DecodeRepeatedField

    def DecodeField(buffer, pos, end, message, field_dict):
      (decoded, pos) = decode_value(buffer, pos)
      if pos > end:
        raise _DecodeError('Truncated message.')
      if not (clear_if_default and not decoded):
        field_dict[key] = decoded
      else:
        field_dict.pop(key, None)
      return pos

    return DecodeField

  return SpecificDecoder
def _ModifiedDecoder(wire_type, decode_value, modify_value):
  """Like _SimpleDecoder but passes every value through modify_value first.

  Usually modify_value is ZigZagDecode.

  Args:
    wire_type: The field's wire type.
    decode_value: Function decoding one raw value; returns (value, new_pos).
    modify_value: Function applied to each raw value before it is stored.

  Returns:
    A decoder constructor built by _SimpleDecoder.
  """

  # Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but
  # not enough to make a significant difference.

  def InnerDecode(buffer, pos):
    raw_and_pos = decode_value(buffer, pos)
    return (modify_value(raw_and_pos[0]), raw_and_pos[1])

  return _SimpleDecoder(wire_type, InnerDecode)
def _StructPackDecoder(wire_type, format):
  """Return a constructor for a decoder for a fixed-width field.

  Args:
    wire_type: The field's wire type.
    format: The format string to pass to struct.unpack().

  Returns:
    A decoder constructor built by _SimpleDecoder.
  """
  unpack = struct.unpack
  width = struct.calcsize(format)

  # Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but
  # not enough to make a significant difference.

  # struct.error is deliberately not handled here: someone up-stack is expected
  # to catch it and convert it to _DecodeError, so no exception-handling block
  # has to be set up every time one value is parsed.

  def InnerDecode(buffer, pos):
    value_end = pos + width
    fields = unpack(format, buffer[pos:value_end])
    return (fields[0], value_end)

  return _SimpleDecoder(wire_type, InnerDecode)
def _FloatDecoder():
  """Returns a decoder for a float field.

  This code works around a bug in struct.unpack for non-finite 32-bit
  floating-point values.
  """
  unpack = struct.unpack

  def InnerDecode(buffer, pos):
    """Decode serialized float to a float and new position.

    Args:
      buffer: memoryview of the serialized bytes
      pos: int, position in the memory view to start at.

    Returns:
      Tuple[float, int] of the deserialized float value and new position
      in the serialized data.
    """
    # The value is 32 bits in little-endian byte order.  Bit 1 is the sign
    # bit, bits 2-9 represent the exponent, and bits 10-32 are the significand.
    new_pos = pos + 4
    raw = buffer[pos:new_pos].tobytes()
    high_byte = raw[3:4]

    # All exponent bits set means the value is non-finite.  In Python 2.4,
    # struct.unpack will convert it to a finite 64-bit value, so such values
    # are parsed specially below.
    exponent_saturated = high_byte in b'\x7F\xFF' and raw[2:3] >= b'\x80'
    if not exponent_saturated:
      # struct.error is left for someone up-stack to catch and convert to
      # _DecodeError, so no exception-handling block is needed per value.
      return (unpack('<f', raw)[0], new_pos)

    # At least one significand bit set: not a number.
    if raw[0:3] != b'\x00\x00\x80':
      return (_NAN, new_pos)
    # Otherwise an infinity whose sign is given by the sign bit.
    if high_byte == b'\xFF':
      return (_NEG_INF, new_pos)
    return (_POS_INF, new_pos)

  return _SimpleDecoder(wire_format.WIRETYPE_FIXED32, InnerDecode)
def _DoubleDecoder():
  """Returns a decoder for a double field.

  This code works around a bug in struct.unpack for not-a-number.
  """
  unpack = struct.unpack

  def InnerDecode(buffer, pos):
    """Decode serialized double to a double and new position.

    Args:
      buffer: memoryview of the serialized bytes.
      pos: int, position in the memory view to start at.

    Returns:
      Tuple[float, int] of the decoded double value and new position
      in the serialized data.
    """
    # The value is 64 bits in little-endian byte order.  Bit 1 is the sign
    # bit, bits 2-12 represent the exponent, and bits 13-64 are the
    # significand.
    new_pos = pos + 8
    raw = buffer[pos:new_pos].tobytes()

    # All exponent bits set plus at least one significand bit set means
    # not-a-number.  In Python 2.4, struct.unpack will treat it as inf or
    # -inf, so it is handled specially.
    looks_like_nan = (
        raw[7:8] in b'\x7F\xFF'
        and raw[6:7] >= b'\xF0'
        and raw[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0')
    if looks_like_nan:
      return (_NAN, new_pos)

    # struct.error is left for someone up-stack to catch and convert to
    # _DecodeError, so no exception-handling block is needed per value.
    return (unpack('<d', raw)[0], new_pos)

  return _SimpleDecoder(wire_format.WIRETYPE_FIXED64, InnerDecode)
def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
                clear_if_default=False):
  """Returns a decoder for enum field.

  Decoded numbers found in key.enum_type.values_by_number are stored in the
  field.  Any other number is kept out of the field and recorded in the
  message's unknown fields (both message._unknown_fields and
  message._unknown_field_set) instead.

  Args:
    field_number: The field number of the field to decode.
    is_repeated: Whether the field is a repeated field.
    is_packed: Whether the field is a packed field.
    key: The key used to look the field up in field_dict; its enum_type
      attribute supplies the set of known numbers.
    new_default: Function taking a message and returning a new default value
      for the field (used to create the container of repeated fields).
    clear_if_default: If true, a singular field decoded as a falsy value is
      removed from field_dict instead of being stored.

  Returns:
    A decoder function with the signature
    Decode(buffer, pos, end, message, field_dict) -> new_pos.
  """
  enum_type = key.enum_type
  # Unknown values are tagged identically in every branch, so the tag is built
  # once here instead of once per unknown element.
  tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)

  def _RecordUnknownValue(message, encoded_value, enum_number):
    """Stores an unrecognized enum number in the message's unknown fields."""
    # pylint: disable=protected-access
    if not message._unknown_fields:
      message._unknown_fields = []
    message._unknown_fields.append((tag_bytes, encoded_value))
    if message._unknown_field_set is None:
      message._unknown_field_set = containers.UnknownFieldSet()
    message._unknown_field_set._add(
        field_number, wire_format.WIRETYPE_VARINT, enum_number)
    # pylint: enable=protected-access

  if is_packed:
    local_DecodeVarint = _DecodeVarint

    def DecodePackedField(buffer, pos, end, message, field_dict):
      """Decode serialized packed enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      (endpoint, pos) = local_DecodeVarint(buffer, pos)
      endpoint += pos
      if endpoint > end:
        raise _DecodeError('Truncated message.')
      while pos < endpoint:
        value_start_pos = pos
        (element, pos) = _DecodeSignedVarint32(buffer, pos)
        if element in enum_type.values_by_number:
          value.append(element)
        else:
          _RecordUnknownValue(
              message, buffer[value_start_pos:pos].tobytes(), element)
      if pos > endpoint:
        # The last element ran past the packed payload; undo whichever store
        # it went to before reporting the error.
        if element in enum_type.values_by_number:
          del value[-1]   # Discard corrupt value.
        else:
          # pylint: disable=protected-access
          del message._unknown_fields[-1]
          del message._unknown_field_set._values[-1]
          # pylint: enable=protected-access
        raise _DecodeError('Packed element was truncated.')
      return pos

    return DecodePackedField
  elif is_repeated:
    tag_len = len(tag_bytes)

    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      """Decode serialized repeated enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        (element, new_pos) = _DecodeSignedVarint32(buffer, pos)
        if element in enum_type.values_by_number:
          value.append(element)
        else:
          _RecordUnknownValue(message, buffer[pos:new_pos].tobytes(), element)
        # Predict that the next tag is another copy of the same repeated
        # field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
          # Prediction failed.  Return.
          if new_pos > end:
            raise _DecodeError('Truncated message.')
          return new_pos

    return DecodeRepeatedField
  else:

    def DecodeField(buffer, pos, end, message, field_dict):
      """Decode serialized singular enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      value_start_pos = pos
      (enum_value, pos) = _DecodeSignedVarint32(buffer, pos)
      if pos > end:
        raise _DecodeError('Truncated message.')
      if clear_if_default and not enum_value:
        field_dict.pop(key, None)
        return pos
      if enum_value in enum_type.values_by_number:
        field_dict[key] = enum_value
      else:
        _RecordUnknownValue(
            message, buffer[value_start_pos:pos].tobytes(), enum_value)
      return pos

    return DecodeField
- # --------------------------------------------------------------------
# Varint-encoded integer types.
Int32Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT,
                              _DecodeSignedVarint32)
Int64Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT,
                              _DecodeSignedVarint)

UInt32Decoder = _SimpleDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint32)
UInt64Decoder = _SimpleDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint)

# ZigZag-encoded signed integer types.
SInt32Decoder = _ModifiedDecoder(wire_format.WIRETYPE_VARINT,
                                 _DecodeVarint32,
                                 wire_format.ZigZagDecode)
SInt64Decoder = _ModifiedDecoder(wire_format.WIRETYPE_VARINT,
                                 _DecodeVarint,
                                 wire_format.ZigZagDecode)

# Fixed-width types.  Python conveniently guarantees that formats using the
# '<' prefix have the same size across all platforms (as opposed to formats
# without the prefix, whose sizes depend on the C compiler's basic type
# sizes).
Fixed32Decoder = _StructPackDecoder(
    wire_format.WIRETYPE_FIXED32, '<I')
Fixed64Decoder = _StructPackDecoder(
    wire_format.WIRETYPE_FIXED64, '<Q')
SFixed32Decoder = _StructPackDecoder(
    wire_format.WIRETYPE_FIXED32, '<i')
SFixed64Decoder = _StructPackDecoder(
    wire_format.WIRETYPE_FIXED64, '<q')

FloatDecoder = _FloatDecoder()
DoubleDecoder = _DoubleDecoder()

# Booleans are varints converted with bool().
BoolDecoder = _ModifiedDecoder(wire_format.WIRETYPE_VARINT,
                               _DecodeVarint,
                               bool)
def StringDecoder(field_number, is_repeated, is_packed, key, new_default,
                  is_strict_utf8=False, clear_if_default=False):
  """Returns a decoder for a string field.

  Args:
    field_number: The field number of the field to decode.
    is_repeated: Whether the field is a repeated field.
    is_packed: Must be false; string fields cannot be packed.
    key: The key used to look the field up in field_dict; its full_name is
      used in error messages.
    new_default: Function taking a message and returning a new default value
      for the field (used to create the container of repeated fields).
    is_strict_utf8: If true, on Python 2 builds whose sys.maxunicode exceeds
      the UCS-2 range, decoded text containing surrogates is rejected.
    clear_if_default: If true, a singular field whose encoded length is zero
      is removed from field_dict instead of being stored.

  Returns:
    A decoder function with the signature
    Decode(buffer, pos, end, message, field_dict) -> new_pos.
  """

  local_DecodeVarint = _DecodeVarint
  local_unicode = six.text_type

  def _ConvertToUnicode(memview):
    """Convert byte to unicode."""
    byte_str = memview.tobytes()
    try:
      value = local_unicode(byte_str, 'utf-8')
    except UnicodeDecodeError as e:
      # add more information to the error message and re-raise it.
      e.reason = '%s in field: %s' % (e, key.full_name)
      raise

    if is_strict_utf8 and six.PY2 and sys.maxunicode > _UCS2_MAXUNICODE:
      # Only do the check for python2 ucs4 when is_strict_utf8 enabled
      if _SURROGATE_PATTERN.search(value):
        # Adjacent string literals are joined with no separator, so each
        # fragment needs its own trailing space.
        reason = ('String field %s contains invalid UTF-8 data when parsing '
                  'a protocol buffer: surrogates not allowed. Use '
                  'the bytes type if you intend to send raw bytes.') % (
                      key.full_name)
        raise _DecodeError(reason)

    return value

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_LENGTH_DELIMITED)
    tag_len = len(tag_bytes)

    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        # Each element is prefixed with its length in bytes.
        (size, pos) = local_DecodeVarint(buffer, pos)
        new_pos = pos + size
        if new_pos > end:
          raise _DecodeError('Truncated string.')
        value.append(_ConvertToUnicode(buffer[pos:new_pos]))
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed.  Return.
          return new_pos

    return DecodeRepeatedField
  else:

    def DecodeField(buffer, pos, end, message, field_dict):
      (size, pos) = local_DecodeVarint(buffer, pos)
      new_pos = pos + size
      if new_pos > end:
        raise _DecodeError('Truncated string.')
      if clear_if_default and not size:
        field_dict.pop(key, None)
      else:
        field_dict[key] = _ConvertToUnicode(buffer[pos:new_pos])
      return new_pos

    return DecodeField
def BytesDecoder(field_number, is_repeated, is_packed, key, new_default,
                 clear_if_default=False):
  """Returns a decoder for a bytes field.

  Args:
    field_number: The field number of the field to decode.
    is_repeated: Whether the field is a repeated field.
    is_packed: Must be false; bytes fields cannot be packed.
    key: The key used to look the field up in field_dict.
    new_default: Function taking a message and returning a new default value
      for the field (used to create the container of repeated fields).
    clear_if_default: If true, a singular field whose encoded length is zero
      is removed from field_dict instead of being stored.

  Returns:
    A decoder function with the signature
    Decode(buffer, pos, end, message, field_dict) -> new_pos.
  """
  read_length = _DecodeVarint

  assert not is_packed

  if not is_repeated:

    def DecodeField(buffer, pos, end, message, field_dict):
      (length, pos) = read_length(buffer, pos)
      payload_end = pos + length
      if payload_end > end:
        raise _DecodeError('Truncated string.')
      if not (clear_if_default and not length):
        field_dict[key] = buffer[pos:payload_end].tobytes()
      else:
        field_dict.pop(key, None)
      return payload_end

    return DecodeField

  tag_bytes = encoder.TagBytes(field_number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)

  def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    container = field_dict.get(key)
    if container is None:
      container = field_dict.setdefault(key, new_default(message))
    while True:
      # Each element is prefixed with its length in bytes.
      (length, pos) = read_length(buffer, pos)
      payload_end = pos + length
      if payload_end > end:
        raise _DecodeError('Truncated string.')
      container.append(buffer[pos:payload_end].tobytes())
      # Predict that the next tag is another copy of the same repeated field.
      pos = payload_end + tag_len
      if buffer[payload_end:pos] == tag_bytes and payload_end != end:
        continue
      # Prediction failed.  Return.
      return payload_end

  return DecodeRepeatedField
def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a group field.

  Args:
    field_number: The field number of the group; used to build the start- and
      end-group tags.
    is_repeated: Whether the field is a repeated field.
    is_packed: Must be false; group fields cannot be packed.
    key: The key used to look the field up in field_dict.
    new_default: Function taking a message and returning a new default value
      for the field, used when no value exists yet.

  Returns:
    A decoder function with the signature
    Decode(buffer, pos, end, message, field_dict) -> new_pos.
  """

  end_tag_bytes = encoder.TagBytes(field_number,
                                   wire_format.WIRETYPE_END_GROUP)
  end_tag_len = len(end_tag_bytes)

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_START_GROUP)
    tag_len = len(tag_bytes)

    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      # The container is looked up once, before the loop, as the other
      # repeated decoders in this file do.
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        # Read sub-message.
        pos = value.add()._InternalParse(buffer, pos, end)
        # Read end tag.
        new_pos = pos + end_tag_len
        if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
          raise _DecodeError('Missing group end tag.')
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed.  Return.
          return new_pos

    return DecodeRepeatedField
  else:

    def DecodeField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      # Read sub-message.
      pos = value._InternalParse(buffer, pos, end)
      # Read end tag.
      new_pos = pos + end_tag_len
      if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
        raise _DecodeError('Missing group end tag.')
      return new_pos

    return DecodeField
def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a message field.

  Args:
    field_number: The field number of the field to decode.
    is_repeated: Whether the field is a repeated field.
    is_packed: Must be false; message fields cannot be packed.
    key: The key used to look the field up in field_dict.
    new_default: Function taking a message and returning a new default value
      for the field, used when no value exists yet.

  Returns:
    A decoder function with the signature
    Decode(buffer, pos, end, message, field_dict) -> new_pos.
  """
  read_length = _DecodeVarint

  assert not is_packed

  if not is_repeated:

    def DecodeField(buffer, pos, end, message, field_dict):
      submessage = field_dict.get(key)
      if submessage is None:
        submessage = field_dict.setdefault(key, new_default(message))
      # Read length.
      (length, pos) = read_length(buffer, pos)
      body_end = pos + length
      if body_end > end:
        raise _DecodeError('Truncated message.')
      # Read sub-message.
      parsed_to = submessage._InternalParse(buffer, pos, body_end)
      if parsed_to != body_end:
        # The only reason _InternalParse would return early is if it
        # encountered an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
      return body_end

    return DecodeField

  tag_bytes = encoder.TagBytes(field_number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)

  def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    container = field_dict.get(key)
    if container is None:
      container = field_dict.setdefault(key, new_default(message))
    while True:
      # Read length.
      (length, pos) = read_length(buffer, pos)
      body_end = pos + length
      if body_end > end:
        raise _DecodeError('Truncated message.')
      # Read sub-message.
      parsed_to = container.add()._InternalParse(buffer, pos, body_end)
      if parsed_to != body_end:
        # The only reason _InternalParse would return early is if it
        # encountered an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
      # Predict that the next tag is another copy of the same repeated field.
      pos = body_end + tag_len
      if buffer[body_end:pos] == tag_bytes and body_end != end:
        continue
      # Prediction failed.  Return.
      return body_end

  return DecodeRepeatedField
- # --------------------------------------------------------------------
MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP)


def MessageSetItemDecoder(descriptor):
  """Returns a decoder for a MessageSet item.

  The parameter is the message Descriptor.  (It is not read by the code in
  this function.)

  The message set message looks like this:
    message MessageSet {
      repeated group Item = 1 {
        required int32 type_id = 2;
        required string message = 3;
      }
    }

  Returns:
    A decoder function with the signature
    Decode(buffer, pos, end, message, field_dict) -> new_pos.
  """

  type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
  message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
  item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)

  local_ReadTag = ReadTag
  local_DecodeVarint = _DecodeVarint
  local_SkipField = SkipField

  def DecodeItem(buffer, pos, end, message, field_dict):
    """Decode serialized message set to its value and new position.

    Args:
      buffer: memoryview of the serialized bytes.
      pos: int, position in the memory view to start at.
      end: int, end position of serialized data
      message: Message object to store unknown fields in
      field_dict: Map[Descriptor, Any] to store decoded values in.

    Returns:
      int, new position in serialized data.

    Raises:
      _DecodeError: if the item is truncated, lacks its end-group tag, or is
        missing its type_id or message.
    """
    message_set_item_start = pos
    # -1 marks "not seen yet" for each of the three values below.
    type_id = -1
    message_start = -1
    message_end = -1

    # Technically, type_id and message can appear in any order, so we need
    # a little loop here.
    while 1:
      (tag_bytes, pos) = local_ReadTag(buffer, pos)
      if tag_bytes == type_id_tag_bytes:
        (type_id, pos) = local_DecodeVarint(buffer, pos)
      elif tag_bytes == message_tag_bytes:
        # Only the extent of the payload is recorded here; it is parsed after
        # the whole item has been read.
        (size, message_start) = local_DecodeVarint(buffer, pos)
        pos = message_end = message_start + size
      elif tag_bytes == item_end_tag_bytes:
        break
      else:
        pos = local_SkipField(buffer, pos, end, tag_bytes)
        if pos == -1:
          raise _DecodeError('Missing group end tag.')

    if pos > end:
      raise _DecodeError('Truncated message.')

    if type_id == -1:
      raise _DecodeError('MessageSet item missing type_id.')
    if message_start == -1:
      raise _DecodeError('MessageSet item missing message.')

    # pylint: disable=protected-access
    extension = message.Extensions._FindExtensionByNumber(type_id)
    if extension is not None:
      value = field_dict.get(extension)
      if value is None:
        message_type = extension.message_type
        if not hasattr(message_type, '_concrete_class'):
          message._FACTORY.GetPrototype(message_type)
        value = field_dict.setdefault(
            extension, message_type._concrete_class())
      if value._InternalParse(buffer, message_start, message_end) != message_end:
        # The only reason _InternalParse would return early is if it encountered
        # an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
    else:
      # No extension is registered for this type_id: keep the raw item bytes
      # and the raw payload as unknown fields.
      if not message._unknown_fields:
        message._unknown_fields = []
      message._unknown_fields.append(
          (MESSAGE_SET_ITEM_TAG, buffer[message_set_item_start:pos].tobytes()))
      if message._unknown_field_set is None:
        message._unknown_field_set = containers.UnknownFieldSet()
      message._unknown_field_set._add(
          type_id,
          wire_format.WIRETYPE_LENGTH_DELIMITED,
          buffer[message_start:message_end].tobytes())
    # pylint: enable=protected-access

    return pos

  return DecodeItem
- # --------------------------------------------------------------------
def MapDecoder(field_descriptor, new_default, is_message_map):
  """Builds the decoder callable for a map field.

  Args:
    field_descriptor: descriptor of the map field; it is also the key under
      which the decoded container is stored in the field dict.
    new_default: callable taking the parent message and returning a fresh
      container for the field.
    is_message_map: when true, entry values are merged in with CopyFrom();
      otherwise they are assigned directly.

  Returns:
    A function (buffer, pos, end, message, field_dict) -> new position.
  """

  expected_tag = encoder.TagBytes(field_descriptor.number,
                                  wire_format.WIRETYPE_LENGTH_DELIMITED)
  expected_tag_len = len(expected_tag)
  decode_varint = _DecodeVarint
  # _concrete_class may not be initialized yet, so only the descriptor is
  # captured here and the class is looked up at decode time.
  entry_type = field_descriptor.message_type

  def DecodeMap(buffer, pos, end, message, field_dict):
    """Decodes consecutive map entries starting at pos into field_dict."""
    # One scratch entry message is reused for every entry in this run.
    entry = entry_type._concrete_class()
    container = field_dict.get(field_descriptor)
    if container is None:
      container = field_dict.setdefault(field_descriptor, new_default(message))

    while True:
      # Each entry is a length-prefixed sub-message.
      (entry_size, entry_start) = decode_varint(buffer, pos)
      entry_end = entry_start + entry_size
      if entry_end > end:
        raise _DecodeError('Truncated message.')

      entry.Clear()
      if entry._InternalParse(buffer, entry_start, entry_end) != entry_end:
        # An early return from _InternalParse means it hit an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')

      if is_message_map:
        container[entry.key].CopyFrom(entry.value)
      else:
        container[entry.key] = entry.value

      # Guess that another entry of the same field follows immediately; if the
      # data ended or the next tag differs, hand control back to the caller.
      pos = entry_end + expected_tag_len
      if entry_end == end or buffer[entry_end:pos] != expected_tag:
        return entry_end

  return DecodeMap
- # --------------------------------------------------------------------
- # Optimization is not as heavy here because calls to SkipField() are rare,
- # except for handling end-group tags.
- def _SkipVarint(buffer, pos, end):
- """Skip a varint value. Returns the new position."""
- # Previously ord(buffer[pos]) raised IndexError when pos is out of range.
- # With this code, ord(b'') raises TypeError. Both are handled in
- # python_message.py to generate a 'Truncated message' error.
- while ord(buffer[pos:pos+1].tobytes()) & 0x80:
- pos += 1
- pos += 1
- if pos > end:
- raise _DecodeError('Truncated message.')
- return pos
- def _SkipFixed64(buffer, pos, end):
- """Skip a fixed64 value. Returns the new position."""
- pos += 8
- if pos > end:
- raise _DecodeError('Truncated message.')
- return pos
- def _DecodeFixed64(buffer, pos):
- """Decode a fixed64."""
- new_pos = pos + 8
- return (struct.unpack('<Q', buffer[pos:new_pos])[0], new_pos)
def _SkipLengthDelimited(buffer, pos, end):
  """Steps over a length-prefixed value and returns the position after it.

  Args:
    buffer: serialized bytes to read the length varint from.
    pos: index of the length varint.
    end: index one past the last valid byte.

  Raises:
    _DecodeError: if the declared length runs beyond end.
  """
  (length, payload_start) = _DecodeVarint(buffer, pos)
  payload_end = payload_start + length
  if payload_end <= end:
    return payload_end
  raise _DecodeError('Truncated message.')
def _SkipGroup(buffer, pos, end):
  """Steps over every field of a sub-group.

  Fields are skipped one at a time until SkipField reports an end-group tag
  by returning -1.

  Returns:
    The position immediately after the end-group tag.
  """
  while True:
    (tag_bytes, value_start) = ReadTag(buffer, pos)
    pos = SkipField(buffer, value_start, end, tag_bytes)
    if pos == -1:
      # The tag just read closed the group; resume right after it.
      return value_start
def _DecodeUnknownFieldSet(buffer, pos, end_pos=None):
  """Decodes a run of fields into an UnknownFieldSet.

  Decoding stops at an end-group tag, or once pos reaches end_pos when
  end_pos is given.

  Returns:
    A (UnknownFieldSet, new_pos) tuple.
  """
  fields = containers.UnknownFieldSet()
  while True:
    if end_pos is not None and pos >= end_pos:
      break
    (tag_bytes, pos) = ReadTag(buffer, pos)
    tag = _DecodeVarint(tag_bytes, 0)[0]
    (field_number, wire_type) = wire_format.UnpackTag(tag)
    if wire_type == wire_format.WIRETYPE_END_GROUP:
      # The end-group tag is consumed; pos already points past it.
      break
    (data, pos) = _DecodeUnknownField(buffer, pos, wire_type)
    # pylint: disable=protected-access
    fields._add(field_number, wire_type, data)

  return (fields, pos)
def _DecodeUnknownField(buffer, pos, wire_type):
  """Decode an unknown field.  Returns the decoded data and new position.

  Args:
    buffer: memoryview of the serialized bytes.
    pos: int, position of the field's value (just after its tag).
    wire_type: int, wire type taken from the field's tag.

  Returns:
    A (data, new_pos) tuple.  data is an int for varint/fixed values, bytes
    for length-delimited values and an UnknownFieldSet for groups.  For an
    end-group tag the result is (0, -1).

  Raises:
    _DecodeError: if the wire type is not recognized, or a length-delimited
      value claims more bytes than the buffer holds.
  """

  if wire_type == wire_format.WIRETYPE_VARINT:
    (data, pos) = _DecodeVarint(buffer, pos)
  elif wire_type == wire_format.WIRETYPE_FIXED64:
    (data, pos) = _DecodeFixed64(buffer, pos)
  elif wire_type == wire_format.WIRETYPE_FIXED32:
    (data, pos) = _DecodeFixed32(buffer, pos)
  elif wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
    (size, pos) = _DecodeVarint(buffer, pos)
    value_end = pos + size
    # Slicing past the end of the buffer silently yields fewer bytes than
    # requested, so reject an over-long length instead of returning short data.
    if value_end > len(buffer):
      raise _DecodeError('Truncated message.')
    data = buffer[pos:value_end].tobytes()
    pos = value_end
  elif wire_type == wire_format.WIRETYPE_START_GROUP:
    (data, pos) = _DecodeUnknownFieldSet(buffer, pos)
  elif wire_type == wire_format.WIRETYPE_END_GROUP:
    # -1 tells the caller that the enclosing group has ended.
    return (0, -1)
  else:
    raise _DecodeError('Wrong wire type in tag.')

  return (data, pos)
- def _EndGroup(buffer, pos, end):
- """Skipping an END_GROUP tag returns -1 to tell the parent loop to break."""
- return -1
- def _SkipFixed32(buffer, pos, end):
- """Skip a fixed32 value. Returns the new position."""
- pos += 4
- if pos > end:
- raise _DecodeError('Truncated message.')
- return pos
- def _DecodeFixed32(buffer, pos):
- """Decode a fixed32."""
- new_pos = pos + 4
- return (struct.unpack('<I', buffer[pos:new_pos])[0], new_pos)
def _RaiseInvalidWireType(buffer, pos, end):
  """Skipper installed for wire types that have no valid meaning.

  The arguments are not read; the call always fails.

  Raises:
    _DecodeError: unconditionally.
  """
  invalid_wire_type_error = _DecodeError('Tag had invalid wire type.')
  raise invalid_wire_type_error
def _FieldSkipper():
  """Builds and returns the SkipField function."""

  # Indexed by wire type: slot i holds the skipper for wire type i.  The last
  # two slots have no valid wire type and always raise.
  skippers_by_wire_type = (
      _SkipVarint,
      _SkipFixed64,
      _SkipLengthDelimited,
      _SkipGroup,
      _EndGroup,
      _SkipFixed32,
      _RaiseInvalidWireType,
      _RaiseInvalidWireType,
  )

  type_mask = wire_format.TAG_TYPE_MASK

  def SkipField(buffer, pos, end, tag_bytes):
    """Skips a field with the specified tag.

    |pos| should point to the byte immediately after the tag.

    Returns:
        The new position (after the tag value), or -1 if the tag is an end-group
        tag (in which case the calling loop should break).
    """

    # Varints are little-endian, so the wire type always sits in the low bits
    # of the tag's first byte.
    first_byte = ord(tag_bytes[0:1])
    skipper = skippers_by_wire_type[first_byte & type_mask]
    return skipper(buffer, pos, end)

  return SkipField

# Module-level skipper shared by the decoders in this file.
SkipField = _FieldSkipper()
|