|
- #! /usr/bin/env python
- # -*- coding: utf-8 -*-
- #
- # Protocol Buffers - Google's data interchange format
- # Copyright 2008 Google Inc. All rights reserved.
- # https://developers.google.com/protocol-buffers/
- #
- # Redistribution and use in source and binary forms, with or without
- # modification, are permitted provided that the following conditions are
- # met:
- #
- # * Redistributions of source code must retain the above copyright
- # notice, this list of conditions and the following disclaimer.
- # * Redistributions in binary form must reproduce the above
- # copyright notice, this list of conditions and the following disclaimer
- # in the documentation and/or other materials provided with the
- # distribution.
- # * Neither the name of Google Inc. nor the names of its
- # contributors may be used to endorse or promote products derived from
- # this software without specific prior written permission.
- #
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
- # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
- # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
- # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
- # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
- # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
- # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
- # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
- # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
- # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
- """Tests python protocol buffers against the golden message.
- Note that the golden messages exercise every known field type, thus this
- test ends up exercising and verifying nearly all of the parsing and
- serialization code in the whole library.
- TODO(kenton): Merge with wire_format_test? It doesn't make a whole lot of
- sense to call this a test of the "message" module, which only declares an
- abstract interface.
- """
- __author__ = 'gps@google.com (Gregory P. Smith)'
- import copy
- import math
- import operator
- import pickle
- import pydoc
- import six
- import sys
- import warnings
- try:
- # Since python 3
- import collections.abc as collections_abc
- except ImportError:
- # Won't work after python 3.8
- import collections as collections_abc
- try:
- import unittest2 as unittest # PY26
- except ImportError:
- import unittest
- try:
- cmp # Python 2
- except NameError:
- cmp = lambda x, y: (x > y) - (x < y) # Python 3
- from google.protobuf import map_proto2_unittest_pb2
- from google.protobuf import map_unittest_pb2
- from google.protobuf import unittest_pb2
- from google.protobuf import unittest_proto3_arena_pb2
- from google.protobuf import descriptor_pb2
- from google.protobuf import descriptor_pool
- from google.protobuf import message_factory
- from google.protobuf import text_format
- from google.protobuf.internal import api_implementation
- from google.protobuf.internal import encoder
- from google.protobuf.internal import more_extensions_pb2
- from google.protobuf.internal import packed_field_test_pb2
- from google.protobuf.internal import test_util
- from google.protobuf.internal import test_proto3_optional_pb2
- from google.protobuf.internal import testing_refleaks
- from google.protobuf import message
- from google.protobuf.internal import _parameterized
- UCS2_MAXUNICODE = 65535
- if six.PY3:
- long = int
- # Python pre-2.6 does not have isinf() or isnan() functions, so we have
- # to provide our own.
def isnan(val):
  """Return True when val is a NaN, without relying on math.isnan()."""
  # A NaN is the only value that fails to compare equal to itself.
  return not val == val
def isinf(val):
  """Return True when val is positive or negative infinity."""
  if val != val:
    # A NaN is not an infinity.
    return False
  # Infinity times zero equals NaN, which never compares equal to itself.
  zeroed = val * 0
  return zeroed != zeroed
def IsPosInf(val):
  """Return True when val is positive infinity."""
  if not isinf(val):
    return False
  return val > 0
def IsNegInf(val):
  """Return True when val is negative infinity."""
  if not isinf(val):
    return False
  return val < 0
- warnings.simplefilter('error', DeprecationWarning)
- @_parameterized.named_parameters(
- ('_proto2', unittest_pb2),
- ('_proto3', unittest_proto3_arena_pb2))
- @testing_refleaks.TestCase
- class MessageTest(unittest.TestCase):
def testBadUtf8String(self, message_module):
  """Parsing invalid UTF-8 into a string field raises UnicodeDecodeError.

  The test is skipped unless the pure-python implementation is active, and
  the error text must name the offending field.
  """
  if api_implementation.Type() != 'python':
    skip_reason = ("Skipping testBadUtf8String, currently only the python "
                   "api implementation raises UnicodeDecodeError when a "
                   "string field contains bad utf-8.")
    self.skipTest(skip_reason)
  payload = test_util.GoldenFileData('bad_utf8_string')
  with self.assertRaises(UnicodeDecodeError) as raised:
    message_module.TestAllTypes.FromString(payload)
  error_text = str(raised.exception)
  self.assertIn('TestAllTypes.optional_string', error_text)
def testGoldenMessage(self, message_module):
  """Golden data survives parse -> serialize and parse -> deepcopy -> serialize."""
  # Proto3 doesn't have the "default_foo" members or foreign enums,
  # and doesn't preserve unknown fields, so for proto3 we use a golden
  # message that doesn't have these fields set.
  is_proto2 = message_module is unittest_pb2
  if is_proto2:
    golden_name = 'golden_message_oneof_implemented'
  else:
    golden_name = 'golden_message_proto3'
  wire_bytes = test_util.GoldenFileData(golden_name)

  parsed = message_module.TestAllTypes()
  parsed.ParseFromString(wire_bytes)
  if is_proto2:
    test_util.ExpectAllFieldsSet(self, parsed)
  self.assertEqual(wire_bytes, parsed.SerializeToString())

  duplicate = copy.deepcopy(parsed)
  self.assertEqual(wire_bytes, duplicate.SerializeToString())
def testGoldenPackedMessage(self, message_module):
  """Golden packed data matches a message with all packed fields set."""
  wire_bytes = test_util.GoldenFileData('golden_packed_fields_message')
  parsed = message_module.TestPackedTypes()
  consumed = parsed.ParseFromString(wire_bytes)

  reference = message_module.TestPackedTypes()
  test_util.SetAllPackedFields(reference)

  # The value returned by ParseFromString must equal the input length.
  self.assertEqual(consumed, len(wire_bytes))
  self.assertEqual(reference, parsed)
  self.assertEqual(wire_bytes, reference.SerializeToString())

  duplicate = copy.deepcopy(parsed)
  self.assertEqual(wire_bytes, duplicate.SerializeToString())
def testParseErrors(self, message_module):
  """FromString rejects non-bytes arguments and malformed wire data."""
  msg = message_module.TestAllTypes()
  with self.assertRaises(TypeError):
    msg.FromString(0)
  with self.assertRaises(Exception):
    msg.FromString('0')

  # TODO(jieluo): Fix cpp extension to raise error instead of warning.
  # b/27494216
  end_tag = encoder.TagBytes(1, 4)
  if api_implementation.Type() == 'python':
    with self.assertRaises(message.DecodeError) as raised:
      msg.FromString(end_tag)
    self.assertEqual('Unexpected end-group tag.', str(raised.exception))

  # Field number 0 is illegal.
  with self.assertRaises(message.DecodeError):
    msg.FromString(b'\3\4')
def testDeterminismParameters(self, message_module):
  """SerializeToString handles every form of the deterministic argument."""
  # This message is always deterministically serialized, even if determinism
  # is disabled, so we can use it to verify that all the determinism
  # parameters work correctly.
  expected_wire = (b'\xe2\x02\nOne string'
                   b'\xe2\x02\nTwo string'
                   b'\xe2\x02\nRed string'
                   b'\xe2\x02\x0bBlue string')
  strings = [
      'One string',
      'Two string',
      'Red string',
      'Blue string',
  ]
  msg = message_module.TestAllTypes()
  msg.repeated_string.extend(strings)
  for flag in (None, False, True):
    self.assertEqual(expected_wire, msg.SerializeToString(deterministic=flag))

  class BadArgError(Exception):
    pass

  class BadArg(object):
    """Object whose truth value cannot be evaluated."""

    def __bool__(self):
      raise BadArgError()

    # Python 2 spelling of the truth-value hook.
    __nonzero__ = __bool__

  # The error raised while evaluating the argument's truth value must
  # propagate to the caller.
  with self.assertRaises(BadArgError):
    msg.SerializeToString(deterministic=BadArg())
def testPickleSupport(self, message_module):
  """A parsed golden message compares equal to its pickle round trip."""
  wire_bytes = test_util.GoldenFileData('golden_message')
  original = message_module.TestAllTypes()
  original.ParseFromString(wire_bytes)
  restored = pickle.loads(pickle.dumps(original))
  self.assertEqual(restored, original)
def testPickleNestedMessage(self, message_module):
  """A message type nested one level deep survives a pickle round trip."""
  nested_cls = message_module.TestPickleNestedMessage.NestedMessage
  original = nested_cls(bb=1)
  restored = pickle.loads(pickle.dumps(original))
  self.assertEqual(restored, original)
def testPickleNestedNestedMessage(self, message_module):
  """A message type nested two levels deep survives a pickle round trip."""
  outer_cls = message_module.TestPickleNestedMessage.NestedMessage
  original = outer_cls.NestedNestedMessage(cc=1)
  serialized = pickle.dumps(original)
  restored = pickle.loads(serialized)
  self.assertEqual(restored, original)
def testPositiveInfinity(self, message_module):
  """+inf in float and double fields parses and re-serializes unchanged."""
  # The optional_float / optional_double entries are shared by both modules.
  scalar_part = (b'\x5D\x00\x00\x80\x7F'
                 b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F')
  # The repeated entries use a different encoding for each module.
  if message_module is unittest_pb2:
    repeated_part = (b'\xCD\x02\x00\x00\x80\x7F'
                     b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\x7F')
  else:
    repeated_part = (b'\xCA\x02\x04\x00\x00\x80\x7F'
                     b'\xD2\x02\x08\x00\x00\x00\x00\x00\x00\xF0\x7F')
  wire_bytes = scalar_part + repeated_part

  parsed = message_module.TestAllTypes()
  parsed.ParseFromString(wire_bytes)
  self.assertTrue(IsPosInf(parsed.optional_float))
  self.assertTrue(IsPosInf(parsed.optional_double))
  self.assertTrue(IsPosInf(parsed.repeated_float[0]))
  self.assertTrue(IsPosInf(parsed.repeated_double[0]))
  self.assertEqual(wire_bytes, parsed.SerializeToString())
def testNegativeInfinity(self, message_module):
  """-inf in float and double fields parses and re-serializes unchanged."""
  # The optional_float / optional_double entries are shared by both modules.
  scalar_part = (b'\x5D\x00\x00\x80\xFF'
                 b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF')
  # The repeated entries use a different encoding for each module.
  if message_module is unittest_pb2:
    repeated_part = (b'\xCD\x02\x00\x00\x80\xFF'
                     b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\xFF')
  else:
    repeated_part = (b'\xCA\x02\x04\x00\x00\x80\xFF'
                     b'\xD2\x02\x08\x00\x00\x00\x00\x00\x00\xF0\xFF')
  wire_bytes = scalar_part + repeated_part

  parsed = message_module.TestAllTypes()
  parsed.ParseFromString(wire_bytes)
  self.assertTrue(IsNegInf(parsed.optional_float))
  self.assertTrue(IsNegInf(parsed.optional_double))
  self.assertTrue(IsNegInf(parsed.repeated_float[0]))
  self.assertTrue(IsNegInf(parsed.repeated_double[0]))
  self.assertEqual(wire_bytes, parsed.SerializeToString())
def testNotANumber(self, message_module):
  """NaN in float and double fields is still NaN after a round trip."""
  wire_bytes = (b'\x5D\x00\x00\xC0\x7F'
                b'\x61\x00\x00\x00\x00\x00\x00\xF8\x7F'
                b'\xCD\x02\x00\x00\xC0\x7F'
                b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF8\x7F')
  parsed = message_module.TestAllTypes()
  parsed.ParseFromString(wire_bytes)
  self.assertTrue(isnan(parsed.optional_float))
  self.assertTrue(isnan(parsed.optional_double))
  self.assertTrue(isnan(parsed.repeated_float[0]))
  self.assertTrue(isnan(parsed.repeated_double[0]))

  # The protocol buffer may serialize to any one of multiple different
  # representations of a NaN. Rather than verify a specific representation,
  # verify the serialized string can be converted into a correctly
  # behaving protocol buffer.
  reparsed = message_module.TestAllTypes()
  reparsed.ParseFromString(parsed.SerializeToString())
  self.assertTrue(isnan(reparsed.optional_float))
  self.assertTrue(isnan(reparsed.optional_double))
  self.assertTrue(isnan(reparsed.repeated_float[0]))
  self.assertTrue(isnan(reparsed.repeated_double[0]))
def testPositiveInfinityPacked(self, message_module):
  """+inf in packed float/double fields parses and re-serializes unchanged."""
  wire_bytes = (b'\xA2\x06\x04\x00\x00\x80\x7F'
                b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\x7F')
  parsed = message_module.TestPackedTypes()
  parsed.ParseFromString(wire_bytes)
  first_float = parsed.packed_float[0]
  self.assertTrue(IsPosInf(first_float))
  first_double = parsed.packed_double[0]
  self.assertTrue(IsPosInf(first_double))
  self.assertEqual(wire_bytes, parsed.SerializeToString())
def testNegativeInfinityPacked(self, message_module):
  """-inf in packed float/double fields parses and re-serializes unchanged."""
  wire_bytes = (b'\xA2\x06\x04\x00\x00\x80\xFF'
                b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\xFF')
  parsed = message_module.TestPackedTypes()
  parsed.ParseFromString(wire_bytes)
  first_float = parsed.packed_float[0]
  self.assertTrue(IsNegInf(first_float))
  first_double = parsed.packed_double[0]
  self.assertTrue(IsNegInf(first_double))
  self.assertEqual(wire_bytes, parsed.SerializeToString())
def testNotANumberPacked(self, message_module):
  """NaN in packed float/double fields is still NaN after a round trip."""
  wire_bytes = (b'\xA2\x06\x04\x00\x00\xC0\x7F'
                b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF8\x7F')
  parsed = message_module.TestPackedTypes()
  parsed.ParseFromString(wire_bytes)
  self.assertTrue(isnan(parsed.packed_float[0]))
  self.assertTrue(isnan(parsed.packed_double[0]))

  # Only the NaN-ness is checked after re-serialization, not the exact bytes.
  reparsed = message_module.TestPackedTypes()
  reparsed.ParseFromString(parsed.SerializeToString())
  self.assertTrue(isnan(reparsed.packed_float[0]))
  self.assertTrue(isnan(reparsed.packed_double[0]))
def testExtremeFloatValues(self, message_module):
  """Checks optional_float at the extremes of the single-precision range.

  Boundary magnitudes of both signs must survive a serialize/parse round
  trip exactly, the largest finite float must round-trip approximately, and
  doubles beyond the float range must read back as +/-inf after assignment.
  """
  msg = message_module.TestAllTypes()

  # Most positive exponent, no significand bits set.
  kMostPosExponentNoSigBits = math.pow(2, 127)
  # Most positive exponent, one significand bit set.
  kMostPosExponentOneSigBit = 1.5 * math.pow(2, 127)
  # Most negative exponent, no significand bits set.
  kMostNegExponentNoSigBits = math.pow(2, -127)
  # Most negative exponent, one significand bit set.
  kMostNegExponentOneSigBit = 1.5 * math.pow(2, -127)

  # Each magnitude is exercised with both signs, in the original order.
  exact_round_trip_values = (
      kMostPosExponentNoSigBits,
      kMostPosExponentOneSigBit,
      -kMostPosExponentNoSigBits,
      -kMostPosExponentOneSigBit,
      kMostNegExponentNoSigBits,
      kMostNegExponentOneSigBit,
      -kMostNegExponentNoSigBits,
      -kMostNegExponentOneSigBit,
  )
  for value in exact_round_trip_values:
    msg.optional_float = value
    msg.ParseFromString(msg.SerializeToString())
    # assertEqual reports both operands on failure, unlike assertTrue(a == b).
    self.assertEqual(msg.optional_float, value)

  # Max 4 bytes float value
  max_float = float.fromhex('0x1.fffffep+127')
  msg.optional_float = max_float
  self.assertAlmostEqual(msg.optional_float, max_float)
  serialized_data = msg.SerializeToString()
  msg.ParseFromString(serialized_data)
  self.assertAlmostEqual(msg.optional_float, max_float)

  # Test set double to float field.
  msg.optional_float = 3.4028235e+39
  self.assertEqual(msg.optional_float, float('inf'))
  serialized_data = msg.SerializeToString()
  msg.ParseFromString(serialized_data)
  self.assertEqual(msg.optional_float, float('inf'))

  msg.optional_float = -3.4028235e+39
  self.assertEqual(msg.optional_float, float('-inf'))

  msg.optional_float = 1.4028235e-39
  self.assertAlmostEqual(msg.optional_float, 1.4028235e-39)
def testExtremeDoubleValues(self, message_module):
  """Checks optional_double at the extremes of the double-precision range.

  Boundary magnitudes of both signs must survive a serialize/parse round
  trip exactly.
  """
  msg = message_module.TestAllTypes()

  # Most positive exponent, no significand bits set.
  kMostPosExponentNoSigBits = math.pow(2, 1023)
  # Most positive exponent, one significand bit set.
  kMostPosExponentOneSigBit = 1.5 * math.pow(2, 1023)
  # Most negative exponent, no significand bits set.
  kMostNegExponentNoSigBits = math.pow(2, -1023)
  # Most negative exponent, one significand bit set.
  kMostNegExponentOneSigBit = 1.5 * math.pow(2, -1023)

  # Each magnitude is exercised with both signs, in the original order.
  exact_round_trip_values = (
      kMostPosExponentNoSigBits,
      kMostPosExponentOneSigBit,
      -kMostPosExponentNoSigBits,
      -kMostPosExponentOneSigBit,
      kMostNegExponentNoSigBits,
      kMostNegExponentOneSigBit,
      -kMostNegExponentNoSigBits,
      -kMostNegExponentOneSigBit,
  )
  for value in exact_round_trip_values:
    msg.optional_double = value
    msg.ParseFromString(msg.SerializeToString())
    # assertEqual reports both operands on failure, unlike assertTrue(a == b).
    self.assertEqual(msg.optional_double, value)
def testFloatPrinting(self, message_module):
  """str() of a message renders a whole-number float with a trailing '.0'."""
  msg = message_module.TestAllTypes()
  msg.optional_float = 2.0
  rendered = str(msg)
  self.assertEqual(rendered, 'optional_float: 2.0\n')
def testHighPrecisionFloatPrinting(self, message_module):
  """The value read from optional_float is unchanged by a round trip."""
  msg = message_module.TestAllTypes()
  msg.optional_float = 0.12345678912345678
  # Capture what the field reports before serializing, not the literal.
  before_round_trip = msg.optional_float
  wire_bytes = msg.SerializeToString()
  msg.ParseFromString(wire_bytes)
  self.assertEqual(before_round_trip, msg.optional_float)
def testHighPrecisionDoublePrinting(self, message_module):
  """str() of a double field shows the precision expected for this Python."""
  msg = message_module.TestAllTypes()
  msg.optional_double = 0.12345678912345678
  # Python 2 is expected to print fewer digits than Python 3.
  if sys.version_info >= (3,):
    expected_text = 'optional_double: 0.12345678912345678\n'
  else:
    expected_text = 'optional_double: 0.123456789123\n'
  self.assertEqual(str(msg), expected_text)
def testUnknownFieldPrinting(self, message_module):
  """An empty message parsed from populated data prints as an empty string."""
  source = message_module.TestAllTypes()
  test_util.SetAllNonLazyFields(source)
  wire_bytes = source.SerializeToString()

  # TestEmptyMessage receives the serialized TestAllTypes data.
  target = message_module.TestEmptyMessage()
  target.ParseFromString(wire_bytes)
  self.assertEqual(str(target), '')
def testAppendRepeatedCompositeField(self, message_module):
  """append() accepts messages; a non-message leaves the field unchanged."""
  nested_cls = message_module.TestAllTypes.NestedMessage
  msg = message_module.TestAllTypes()
  msg.repeated_nested_message.append(nested_cls(bb=1))
  second = nested_cls(bb=2)
  msg.repeated_nested_message.append(second)

  # A TypeError is tolerated here; what matters is that nothing was added.
  try:
    msg.repeated_nested_message.append(1)
  except TypeError:
    pass

  self.assertEqual(2, len(msg.repeated_nested_message))
  observed = [entry.bb for entry in msg.repeated_nested_message]
  self.assertEqual([1, 2], observed)
def testInsertRepeatedCompositeField(self, message_module):
  """insert() on a repeated message field follows list.insert() positions.

  Negative and out-of-range indices are used. A reference obtained before
  later inserts must still read its original value afterwards, and a
  non-message argument must leave the field unchanged.
  """
  nested_cls = message_module.TestAllTypes.NestedMessage
  msg = message_module.TestAllTypes()
  msg.repeated_nested_message.insert(-1, nested_cls(bb=1))
  # Keep a reference to the first element before shifting it around.
  sub_msg = msg.repeated_nested_message[0]
  msg.repeated_nested_message.insert(0, nested_cls(bb=2))
  msg.repeated_nested_message.insert(99, nested_cls(bb=3))
  msg.repeated_nested_message.insert(-2, nested_cls(bb=-1))
  msg.repeated_nested_message.insert(-1000, nested_cls(bb=-1000))

  # A TypeError is tolerated here; what matters is that nothing was added.
  try:
    msg.repeated_nested_message.insert(1, 999)
  except TypeError:
    pass

  self.assertEqual(5, len(msg.repeated_nested_message))
  self.assertEqual([-1000, 2, -1, 1, 3],
                   [m.bb for m in msg.repeated_nested_message])
  # Text format indents the fields of a nested message by two spaces.
  self.assertEqual(str(msg),
                   'repeated_nested_message {\n'
                   '  bb: -1000\n'
                   '}\n'
                   'repeated_nested_message {\n'
                   '  bb: 2\n'
                   '}\n'
                   'repeated_nested_message {\n'
                   '  bb: -1\n'
                   '}\n'
                   'repeated_nested_message {\n'
                   '  bb: 1\n'
                   '}\n'
                   'repeated_nested_message {\n'
                   '  bb: 3\n'
                   '}\n')
  self.assertEqual(sub_msg.bb, 1)
def testMergeFromRepeatedField(self, message_module):
  """MergeFrom on a repeated container appends the other container's items."""
  msg = message_module.TestAllTypes()
  for value in (1, 3):
    msg.repeated_int32.append(value)
  for bb in (1, 2):
    msg.repeated_nested_message.add(bb=bb)

  other_msg = message_module.TestAllTypes()
  for bb in (3, 4):
    other_msg.repeated_nested_message.add(bb=bb)
  for value in (5, 7):
    other_msg.repeated_int32.append(value)

  msg.repeated_int32.MergeFrom(other_msg.repeated_int32)
  self.assertEqual(4, len(msg.repeated_int32))

  msg.repeated_nested_message.MergeFrom(other_msg.repeated_nested_message)
  merged_bbs = [entry.bb for entry in msg.repeated_nested_message]
  self.assertEqual([1, 2, 3, 4], merged_bbs)
def testAddWrongRepeatedNestedField(self, message_module):
  """add() with invalid arguments must not leave an element behind."""
  msg = message_module.TestAllTypes()

  # A positional argument: TypeError is tolerated.
  try:
    msg.repeated_nested_message.add('wrong')
  except TypeError:
    pass

  # A keyword that is not passed elsewhere in this file: ValueError is
  # tolerated.
  try:
    msg.repeated_nested_message.add(value_field='wrong')
  except ValueError:
    pass

  element_count = len(msg.repeated_nested_message)
  self.assertEqual(element_count, 0)
def testRepeatedContains(self, message_module):
  """The `in` operator works on scalar and composite repeated fields."""
  nested_cls = message_module.TestAllTypes.NestedMessage
  msg = message_module.TestAllTypes()

  msg.repeated_int32.extend([1, 2, 3])
  self.assertIn(2, msg.repeated_int32)
  self.assertNotIn(0, msg.repeated_int32)

  # Elements are placed via add(), append() and insert() respectively.
  msg.repeated_nested_message.add(bb=1)
  added = msg.repeated_nested_message[0]
  appended = nested_cls(bb=2)
  inserted = nested_cls(bb=3)
  msg.repeated_nested_message.append(appended)
  msg.repeated_nested_message.insert(0, inserted)
  self.assertIn(added, msg.repeated_nested_message)
  self.assertIn(appended, msg.repeated_nested_message)
  self.assertIn(inserted, msg.repeated_nested_message)
def testRepeatedScalarIterable(self, message_module):
  """Iterating a repeated scalar field yields every stored value."""
  msg = message_module.TestAllTypes()
  msg.repeated_int32.extend([1, 2, 3])
  # sum() walks the container through the iterator protocol.
  total = sum(msg.repeated_int32)
  self.assertEqual(total, 6)
def testRepeatedNestedFieldIteration(self, message_module):
  """Forward, reversed() and [::-1] iteration of a repeated message field."""
  msg = message_module.TestAllTypes()
  for bb in (1, 2, 3, 4):
    msg.repeated_nested_message.add(bb=bb)

  forward = [entry.bb for entry in msg.repeated_nested_message]
  self.assertEqual([1, 2, 3, 4], forward)

  via_reversed = [entry.bb for entry in reversed(msg.repeated_nested_message)]
  self.assertEqual([4, 3, 2, 1], via_reversed)

  via_slice = [entry.bb for entry in msg.repeated_nested_message[::-1]]
  self.assertEqual([4, 3, 2, 1], via_slice)
def testSortingRepeatedScalarFieldsDefaultComparator(self, message_module):
  """Sorts int32, float, string and bytes fields with the default ordering."""
  msg = message_module.TestAllTypes()

  # TODO(mattp): would testing more scalar types strengthen test?
  for value in (1, 3, 2):
    msg.repeated_int32.append(value)
  msg.repeated_int32.sort()
  for index, expected in enumerate((1, 2, 3)):
    self.assertEqual(msg.repeated_int32[index], expected)
  self.assertEqual(str(msg.repeated_int32), str([1, 2, 3]))

  for value in (1.1, 1.3, 1.2):
    msg.repeated_float.append(value)
  msg.repeated_float.sort()
  for index, expected in enumerate((1.1, 1.2, 1.3)):
    self.assertAlmostEqual(msg.repeated_float[index], expected)

  for value in ('a', 'c', 'b'):
    msg.repeated_string.append(value)
  msg.repeated_string.sort()
  for index, expected in enumerate(('a', 'b', 'c')):
    self.assertEqual(msg.repeated_string[index], expected)
  self.assertEqual(str(msg.repeated_string), str([u'a', u'b', u'c']))

  for value in (b'a', b'c', b'b'):
    msg.repeated_bytes.append(value)
  msg.repeated_bytes.sort()
  for index, expected in enumerate((b'a', b'b', b'c')):
    self.assertEqual(msg.repeated_bytes[index], expected)
  self.assertEqual(str(msg.repeated_bytes), str([b'a', b'b', b'c']))
def testSortingRepeatedScalarFieldsCustomComparator(self, message_module):
  """Sorts int32 and string fields using a key function."""
  msg = message_module.TestAllTypes()

  for value in (-3, -2, -1):
    msg.repeated_int32.append(value)
  msg.repeated_int32.sort(key=abs)
  # Ordered by magnitude, so the input order is reversed.
  for index, expected in enumerate((-1, -2, -3)):
    self.assertEqual(msg.repeated_int32[index], expected)

  for value in ('aaa', 'bb', 'c'):
    msg.repeated_string.append(value)
  msg.repeated_string.sort(key=len)
  # Ordered by length, shortest first.
  for index, expected in enumerate(('c', 'bb', 'aaa')):
    self.assertEqual(msg.repeated_string[index], expected)
def testSortingRepeatedCompositeFieldsCustomComparator(self, message_module):
  """Sorts a repeated message field by its bb attribute via a key function."""
  msg = message_module.TestAllTypes()
  for bb in (1, 3, 2, 6, 5, 4):
    msg.repeated_nested_message.add().bb = bb

  by_bb = operator.attrgetter('bb')
  msg.repeated_nested_message.sort(key=by_bb)

  for index, expected in enumerate((1, 2, 3, 4, 5, 6)):
    self.assertEqual(msg.repeated_nested_message[index].bb, expected)
  self.assertEqual(str(msg.repeated_nested_message),
                   '[bb: 1\n, bb: 2\n, bb: 3\n, bb: 4\n, bb: 5\n, bb: 6\n]')
def testSortingRepeatedCompositeFieldsStable(self, message_module):
  """Sorting a repeated message field keeps equal-key elements in order."""
  expected_order = [13, 11, 10, 21, 20, 24, 33]
  msg = message_module.TestAllTypes()
  for bb in (21, 20, 13, 33, 11, 24, 10):
    msg.repeated_nested_message.add().bb = bb

  # The key is the tens digit, so several elements share each key.
  msg.repeated_nested_message.sort(key=lambda z: z.bb // 10)
  sorted_bbs = [entry.bb for entry in msg.repeated_nested_message]
  self.assertEqual(expected_order, sorted_bbs)

  # Make sure that for the C++ implementation, the underlying fields
  # are actually reordered.
  wire_bytes = msg.SerializeToString()
  msg.Clear()
  msg.MergeFromString(wire_bytes)
  reparsed_bbs = [entry.bb for entry in msg.repeated_nested_message]
  self.assertEqual(expected_order, reparsed_bbs)
def testRepeatedCompositeFieldSortArguments(self, message_module):
  """Sorts a repeated message field with key, reverse and (PY2) cmp args."""
  ascending = [1, 2, 3, 4, 5, 6]
  descending = [6, 5, 4, 3, 2, 1]
  get_bb = operator.attrgetter('bb')

  msg = message_module.TestAllTypes()
  for bb in (1, 3, 2, 6, 5, 4):
    msg.repeated_nested_message.add().bb = bb

  msg.repeated_nested_message.sort(key=get_bb)
  self.assertEqual([k.bb for k in msg.repeated_nested_message], ascending)
  msg.repeated_nested_message.sort(key=get_bb, reverse=True)
  self.assertEqual([k.bb for k in msg.repeated_nested_message], descending)

  if sys.version_info >= (3,):
    # No cmp sorting in PY3.
    return

  cmp_bb = lambda a, b: cmp(a.bb, b.bb)
  msg.repeated_nested_message.sort(sort_function=cmp_bb)
  self.assertEqual([k.bb for k in msg.repeated_nested_message], ascending)
  msg.repeated_nested_message.sort(cmp=cmp_bb, reverse=True)
  self.assertEqual([k.bb for k in msg.repeated_nested_message], descending)
def testRepeatedScalarFieldSortArguments(self, message_module):
  """Sorts scalar repeated fields with key, reverse and (PY2) cmp args."""
  # No cmp sorting in PY3.
  has_cmp_sort = sys.version_info < (3,)
  msg = message_module.TestAllTypes()

  for value in (-3, -2, -1):
    msg.repeated_int32.append(value)
  msg.repeated_int32.sort(key=abs)
  self.assertEqual(list(msg.repeated_int32), [-1, -2, -3])
  msg.repeated_int32.sort(key=abs, reverse=True)
  self.assertEqual(list(msg.repeated_int32), [-3, -2, -1])
  if has_cmp_sort:
    abs_cmp = lambda a, b: cmp(abs(a), abs(b))
    msg.repeated_int32.sort(sort_function=abs_cmp)
    self.assertEqual(list(msg.repeated_int32), [-1, -2, -3])
    msg.repeated_int32.sort(cmp=abs_cmp, reverse=True)
    self.assertEqual(list(msg.repeated_int32), [-3, -2, -1])

  for value in ('aaa', 'bb', 'c'):
    msg.repeated_string.append(value)
  msg.repeated_string.sort(key=len)
  self.assertEqual(list(msg.repeated_string), ['c', 'bb', 'aaa'])
  msg.repeated_string.sort(key=len, reverse=True)
  self.assertEqual(list(msg.repeated_string), ['aaa', 'bb', 'c'])
  if has_cmp_sort:
    len_cmp = lambda a, b: cmp(len(a), len(b))
    msg.repeated_string.sort(sort_function=len_cmp)
    self.assertEqual(list(msg.repeated_string), ['c', 'bb', 'aaa'])
    msg.repeated_string.sort(cmp=len_cmp, reverse=True)
    self.assertEqual(list(msg.repeated_string), ['aaa', 'bb', 'c'])
def testRepeatedFieldsComparable(self, message_module):
  """Under Python 2, equal messages and repeated fields compare as equal.

  Nothing is asserted on Python 3, where the method returns right after
  populating the messages.
  """
  m1 = message_module.TestAllTypes()
  m2 = message_module.TestAllTypes()
  for target in (m1, m2):
    for value in (0, 1, 2):
      target.repeated_int32.append(value)
  for target in (m1, m2):
    for bb in (1, 2, 3):
      target.repeated_nested_message.add().bb = bb

  if sys.version_info >= (3,):
    # No cmp() in PY3.
    return

  # These comparisons should not raise errors.
  _ = m1 < m2
  _ = m1.repeated_nested_message < m2.repeated_nested_message

  # Make sure cmp always works. If it wasn't defined, these would be
  # id() comparisons and would all fail.
  self.assertEqual(cmp(m1, m2), 0)
  self.assertEqual(cmp(m1.repeated_int32, m2.repeated_int32), 0)
  self.assertEqual(cmp(m1.repeated_int32, [0, 1, 2]), 0)
  self.assertEqual(
      cmp(m1.repeated_nested_message, m2.repeated_nested_message), 0)
  with self.assertRaises(TypeError):
    # Can't compare repeated composite containers to lists.
    cmp(m1.repeated_nested_message, m2.repeated_nested_message[:])

  # TODO(anuraag): Implement extensiondict comparison in C++ and then add test
def testRepeatedFieldsAreSequences(self, message_module):
  """Scalar and composite repeated fields are MutableSequence instances."""
  msg = message_module.TestAllTypes()
  sequence_abc = collections_abc.MutableSequence
  self.assertIsInstance(msg.repeated_int32, sequence_abc)
  self.assertIsInstance(msg.repeated_nested_message, sequence_abc)
def testRepeatedFieldsNotHashable(self, message_module):
  """hash() raises TypeError for scalar and composite repeated fields."""
  msg = message_module.TestAllTypes()
  self.assertRaises(TypeError, hash, msg.repeated_int32)
  self.assertRaises(TypeError, hash, msg.repeated_nested_message)
def testRepeatedFieldInsideNestedMessage(self, message_module):
  """Extending a sub-message's repeated field with [] marks 'payload' set."""
  outer = message_module.NestedTestAllTypes()
  nothing_to_add = []
  outer.payload.repeated_int32.extend(nothing_to_add)
  payload_present = outer.HasField('payload')
  self.assertTrue(payload_present)
def testMergeFrom(self, message_module):
  """Checks MergeFrom against sub-messages and containers read beforehand.

  Three scenarios are covered: a directly nested optional message, an
  optional message two levels down, and a repeated nested-message field.
  In the first two, a reference obtained *before* the merge must show the
  merged value afterwards.
  """
  m1 = message_module.TestAllTypes()
  m2 = message_module.TestAllTypes()
  # Cpp extension will lazily create a sub message which is immutable.
  nested = m1.optional_nested_message
  self.assertEqual(0, nested.bb)
  m2.optional_nested_message.bb = 1
  # Make sure cmessage pointing to a mutable message after merge instead of
  # the lazily created message.
  m1.MergeFrom(m2)
  self.assertEqual(1, nested.bb)

  # Test more nested sub message.
  msg1 = message_module.NestedTestAllTypes()
  msg2 = message_module.NestedTestAllTypes()
  nested = msg1.child.payload.optional_nested_message
  self.assertEqual(0, nested.bb)
  msg2.child.payload.optional_nested_message.bb = 1
  msg1.MergeFrom(msg2)
  self.assertEqual(1, nested.bb)

  # Test repeated field.
  # The self-comparison reads the destination container before the merge.
  self.assertEqual(msg1.payload.repeated_nested_message,
                   msg1.payload.repeated_nested_message)
  nested = msg2.payload.repeated_nested_message.add()
  nested.bb = 1
  msg1.MergeFrom(msg2)
  self.assertEqual(1, len(msg1.payload.repeated_nested_message))
  # The source element must be left untouched by the merge ...
  self.assertEqual(1, nested.bb)
  # ... and the element that arrived in the destination must carry its value.
  self.assertEqual(1, msg1.payload.repeated_nested_message[0].bb)
def testMergeFromString(self, message_module):
  """MergeFromString updates a sub-message that was read before the merge."""
  target = message_module.TestAllTypes()
  source = message_module.TestAllTypes()

  # Cpp extension will lazily create a sub message which is immutable.
  self.assertEqual(0, target.optional_nested_message.bb)
  source.optional_nested_message.bb = 1

  # Make sure cmessage pointing to a mutable message after merge instead of
  # the lazily created message.
  wire_bytes = source.SerializeToString()
  target.MergeFromString(wire_bytes)
  self.assertEqual(1, target.optional_nested_message.bb)
def testMergeFromStringUsingMemoryView(self, message_module):
  """FromString accepts a memoryview and yields bytes / text field values."""
  source = message_module.TestAllTypes()
  source.optional_string = 'scalar string'
  source.repeated_string.append('repeated string')
  source.optional_bytes = b'scalar bytes'
  source.repeated_bytes.append(b'repeated bytes')

  wire_bytes = source.SerializeToString()
  parsed = message_module.TestAllTypes.FromString(memoryview(wire_bytes))

  expected_values = (
      ('optional_bytes', b'scalar bytes'),
      ('repeated_bytes', [b'repeated bytes']),
      ('optional_string', 'scalar string'),
      ('repeated_string', ['repeated string']),
  )
  for field_name, expected in expected_values:
    self.assertEqual(getattr(parsed, field_name), expected)

  # Make sure that the memoryview was correctly converted to bytes, and
  # that a sub-sliced memoryview is not being used.
  self.assertIsInstance(parsed.optional_bytes, bytes)
  self.assertIsInstance(parsed.repeated_bytes[0], bytes)
  self.assertIsInstance(parsed.optional_string, six.text_type)
  self.assertIsInstance(parsed.repeated_string[0], six.text_type)
def testMergeFromEmpty(self, message_module):
  """Merging empty input leaves an unset sub-message unset."""
  field_name = 'optional_nested_message'
  msg = message_module.TestAllTypes()

  # Cpp extension will lazily create a sub message which is immutable.
  self.assertEqual(0, msg.optional_nested_message.bb)
  self.assertFalse(msg.HasField(field_name))

  # Make sure the sub message is still immutable after merge from empty.
  empty_payload = b''
  msg.MergeFromString(empty_payload)  # field state should not change
  self.assertFalse(msg.HasField(field_name))
def ensureNestedMessageExists(self, msg, attribute):
  """Read a nested message attribute and assert it is still reported unset.

  The original authors note that accessing a nested message attribute makes
  it present in the _fields dict without marking it as actually being set.

  Args:
    msg: message object exposing `attribute` and HasField().
    attribute: name of the nested message field to touch.
  """
  # Plain read access; the returned object is deliberately discarded.
  getattr(msg, attribute)
  reported_set = msg.HasField(attribute)
  self.assertFalse(reported_set)
def testOneofGetCaseNonexistingField(self, message_module):
  """WhichOneof rejects an unknown oneof name and a non-string argument."""
  msg = message_module.TestAllTypes()
  with self.assertRaises(ValueError):
    msg.WhichOneof('no_such_oneof_field')
  with self.assertRaises(Exception):
    msg.WhichOneof(0)
def testOneofDefaultValues(self, message_module):
  """Assigning a default value to a oneof member still selects that member."""
  msg = message_module.TestAllTypes()
  self.assertIsNone(msg.WhichOneof('oneof_field'))
  for field_name in ('oneof_field', 'oneof_uint32'):
    self.assertFalse(msg.HasField(field_name))

  # Oneof is set even when setting it to a default value.
  msg.oneof_uint32 = 0
  self.assertEqual('oneof_uint32', msg.WhichOneof('oneof_field'))
  for field_name in ('oneof_field', 'oneof_uint32'):
    self.assertTrue(msg.HasField(field_name))
  self.assertFalse(msg.HasField('oneof_string'))

  msg.oneof_string = ""
  self.assertEqual('oneof_string', msg.WhichOneof('oneof_field'))
  self.assertTrue(msg.HasField('oneof_string'))
  self.assertFalse(msg.HasField('oneof_uint32'))
def testOneofSemantics(self, message_module):
  """Walks a oneof through several members and checks which one is active.

  Reading the nested-message member (or one of its fields) must not switch
  the oneof; only assigning to it does.
  """
  msg = message_module.TestAllTypes()

  def expect(active, *presence):
    # `presence` holds (field name, expected HasField result) pairs, checked
    # in the order given.
    self.assertEqual(active, msg.WhichOneof('oneof_field'))
    for field_name, is_set in presence:
      check = self.assertTrue if is_set else self.assertFalse
      check(msg.HasField(field_name))

  self.assertIsNone(msg.WhichOneof('oneof_field'))

  msg.oneof_uint32 = 11
  expect('oneof_uint32', ('oneof_uint32', True))

  msg.oneof_string = u'foo'
  expect('oneof_string', ('oneof_uint32', False), ('oneof_string', True))

  # Read nested message accessor without accessing submessage.
  getattr(msg, 'oneof_nested_message')
  expect('oneof_string',
         ('oneof_string', True), ('oneof_nested_message', False))

  # Read accessor of nested message without accessing submessage.
  getattr(msg.oneof_nested_message, 'bb')
  expect('oneof_string',
         ('oneof_string', True), ('oneof_nested_message', False))

  msg.oneof_nested_message.bb = 11
  expect('oneof_nested_message',
         ('oneof_string', False), ('oneof_nested_message', True))

  msg.oneof_bytes = b'bb'
  expect('oneof_bytes',
         ('oneof_nested_message', False), ('oneof_bytes', True))
def testOneofCompositeFieldReadAccess(self, message_module):
  """Reading the nested-message member leaves the selected scalar intact."""
  msg = message_module.TestAllTypes()
  msg.oneof_uint32 = 11
  self.ensureNestedMessageExists(msg, 'oneof_nested_message')

  selected = msg.WhichOneof('oneof_field')
  self.assertEqual('oneof_uint32', selected)
  stored_value = msg.oneof_uint32
  self.assertEqual(11, stored_value)
def testOneofWhichOneof(self, message_module):
  """WhichOneof follows assignments and ClearField on oneof members.

  HasField on the oneof name itself is only checked when the module under
  test is unittest_pb2.
  """
  msg = message_module.TestAllTypes()
  check_oneof_presence = message_module is unittest_pb2

  self.assertIsNone(msg.WhichOneof('oneof_field'))
  if check_oneof_presence:
    self.assertFalse(msg.HasField('oneof_field'))

  msg.oneof_uint32 = 11
  self.assertEqual('oneof_uint32', msg.WhichOneof('oneof_field'))
  if check_oneof_presence:
    self.assertTrue(msg.HasField('oneof_field'))

  msg.oneof_bytes = b'bb'
  self.assertEqual('oneof_bytes', msg.WhichOneof('oneof_field'))

  msg.ClearField('oneof_bytes')
  self.assertIsNone(msg.WhichOneof('oneof_field'))
  if check_oneof_presence:
    self.assertFalse(msg.HasField('oneof_field'))
def testOneofClearField(self, message_module):
  """ClearField on the oneof name unsets whichever member was selected."""
  msg = message_module.TestAllTypes()
  msg.oneof_uint32 = 11
  msg.ClearField('oneof_field')

  # HasField on the oneof name is only checked for unittest_pb2.
  check_oneof_presence = message_module is unittest_pb2
  if check_oneof_presence:
    self.assertFalse(msg.HasField('oneof_field'))
  self.assertFalse(msg.HasField('oneof_uint32'))
  self.assertIsNone(msg.WhichOneof('oneof_field'))
def testOneofClearSetField(self, message_module):
  """ClearField on the selected member leaves the whole oneof unset."""
  msg = message_module.TestAllTypes()
  msg.oneof_uint32 = 11
  msg.ClearField('oneof_uint32')

  # HasField on the oneof name is only checked for unittest_pb2.
  check_oneof_presence = message_module is unittest_pb2
  if check_oneof_presence:
    self.assertFalse(msg.HasField('oneof_field'))
  self.assertFalse(msg.HasField('oneof_uint32'))
  self.assertIsNone(msg.WhichOneof('oneof_field'))
def testOneofClearUnsetField(self, message_module):
  """Clearing a member that is not selected keeps the selected one."""
  msg = message_module.TestAllTypes()
  msg.oneof_uint32 = 11
  self.ensureNestedMessageExists(msg, 'oneof_nested_message')
  msg.ClearField('oneof_nested_message')

  self.assertEqual(11, msg.oneof_uint32)
  # HasField on the oneof name is only checked for unittest_pb2.
  check_oneof_presence = message_module is unittest_pb2
  if check_oneof_presence:
    self.assertTrue(msg.HasField('oneof_field'))
  self.assertTrue(msg.HasField('oneof_uint32'))
  self.assertEqual('oneof_uint32', msg.WhichOneof('oneof_field'))
def testOneofDeserialize(self, message_module):
  """The selected oneof member survives a serialize/parse round trip."""
  source = message_module.TestAllTypes()
  source.oneof_uint32 = 11
  wire_bytes = source.SerializeToString()

  parsed = message_module.TestAllTypes()
  parsed.ParseFromString(wire_bytes)
  self.assertEqual('oneof_uint32', parsed.WhichOneof('oneof_field'))
def testOneofCopyFrom(self, message_module):
  """CopyFrom carries the selected oneof member over to the copy."""
  source = message_module.TestAllTypes()
  source.oneof_uint32 = 11

  duplicate = message_module.TestAllTypes()
  duplicate.CopyFrom(source)
  selected = duplicate.WhichOneof('oneof_field')
  self.assertEqual('oneof_uint32', selected)
def testOneofNestedMergeFrom(self, message_module):
  """MergeFrom replaces a oneof only where the source has one selected."""
  source = message_module.NestedTestAllTypes()
  source.payload.oneof_uint32 = 11

  target = message_module.NestedTestAllTypes()
  target.payload.oneof_bytes = b'bb'
  target.child.payload.oneof_bytes = b'bb'
  target.MergeFrom(source)

  top_level_member = target.payload.WhichOneof('oneof_field')
  self.assertEqual('oneof_uint32', top_level_member)
  child_member = target.child.payload.WhichOneof('oneof_field')
  self.assertEqual('oneof_bytes', child_member)
def testOneofMessageMergeFrom(self, message_module):
  """Merging a message-typed oneof member selects it in the target."""
  source = message_module.NestedTestAllTypes()
  source.payload.oneof_nested_message.bb = 11
  source.child.payload.oneof_nested_message.bb = 12

  target = message_module.NestedTestAllTypes()
  target.payload.oneof_uint32 = 13
  target.MergeFrom(source)

  top_level_member = target.payload.WhichOneof('oneof_field')
  self.assertEqual('oneof_nested_message', top_level_member)
  child_member = target.child.payload.WhichOneof('oneof_field')
  self.assertEqual('oneof_nested_message', child_member)
def testOneofNestedMessageInit(self, message_module):
  """Passing a nested message to the constructor selects that member."""
  empty_nested = message_module.TestAllTypes.NestedMessage()
  msg = message_module.TestAllTypes(oneof_nested_message=empty_nested)
  selected = msg.WhichOneof('oneof_field')
  self.assertEqual('oneof_nested_message', selected)
def testOneofClear(self, message_module):
  """Clear() resets the oneof, which can then be assigned again."""
  msg = message_module.TestAllTypes()
  msg.oneof_uint32 = 11
  msg.Clear()
  self.assertIs(None, msg.WhichOneof('oneof_field'))

  msg.oneof_bytes = b'bb'
  selected = msg.WhichOneof('oneof_field')
  self.assertEqual('oneof_bytes', selected)
def testAssignByteStringToUnicodeField(self, message_module):
  """Assigning a byte string to a string field should result
  in the value being converted to a Unicode string."""
  msg = message_module.TestAllTypes()
  native_empty = str('')
  msg.optional_string = native_empty
  stored = msg.optional_string
  self.assertIsInstance(stored, six.text_type)
def testLongValuedSlice(self, message_module):
  """It should be possible to use long-valued indices in slices.

  This didn't used to work in the v2 C++ implementation.
  """
  msg = message_module.TestAllTypes()

  # Repeated scalar
  msg.repeated_int32.append(1)
  scalar_stop = long(len(msg.repeated_int32))
  scalar_slice = msg.repeated_int32[long(0):scalar_stop]
  self.assertEqual(len(msg.repeated_int32), len(scalar_slice))

  # Repeated composite
  msg.repeated_nested_message.add().bb = 3
  composite_stop = long(len(msg.repeated_nested_message))
  composite_slice = msg.repeated_nested_message[long(0):composite_stop]
  self.assertEqual(len(msg.repeated_nested_message), len(composite_slice))
def testExtendShouldNotSwallowExceptions(self, message_module):
  """This didn't use to work in the v2 C++ implementation."""
  msg = message_module.TestAllTypes()
  # The generator refers to the undefined name `a` on purpose, so that
  # iterating it inside extend() raises NameError.
  for field_name in ('repeated_int32', 'repeated_nested_enum'):
    with self.assertRaises(NameError):
      getattr(msg, field_name).extend(
          a for _ in range(10))  # pylint: disable=undefined-variable
# Falsy inputs that the testExtend*WithNothing tests pass to extend(); after
# each call those tests assert that the repeated field is still empty.
FALSY_VALUES = [None, False, 0, 0.0, b'', u'', bytearray(), [], {}, set()]
def testExtendInt32WithNothing(self, message_module):
  """Test no-ops extending repeated int32 fields."""
  msg = message_module.TestAllTypes()
  self.assertSequenceEqual([], msg.repeated_int32)

  # TODO(ptucker): Deprecate this behavior. b/18413862
  # Every falsy value is tried first, followed by one explicit empty list.
  for nothing in MessageTest.FALSY_VALUES + [[]]:
    msg.repeated_int32.extend(nothing)
    self.assertSequenceEqual([], msg.repeated_int32)
def testExtendFloatWithNothing(self, message_module):
  """Test no-ops extending repeated float fields."""
  msg = message_module.TestAllTypes()
  self.assertSequenceEqual([], msg.repeated_float)

  # TODO(ptucker): Deprecate this behavior. b/18413862
  # Every falsy value is tried first, followed by one explicit empty list.
  for nothing in MessageTest.FALSY_VALUES + [[]]:
    msg.repeated_float.extend(nothing)
    self.assertSequenceEqual([], msg.repeated_float)
def testExtendStringWithNothing(self, message_module):
  """Test no-ops extending repeated string fields."""
  msg = message_module.TestAllTypes()
  self.assertSequenceEqual([], msg.repeated_string)

  # TODO(ptucker): Deprecate this behavior. b/18413862
  # Every falsy value is tried first, followed by one explicit empty list.
  for nothing in MessageTest.FALSY_VALUES + [[]]:
    msg.repeated_string.extend(nothing)
    self.assertSequenceEqual([], msg.repeated_string)
def testExtendInt32WithPythonList(self, message_module):
  """Test extending repeated int32 fields with python lists."""
  msg = message_module.TestAllTypes()
  expected = []
  self.assertSequenceEqual(expected, msg.repeated_int32)
  for chunk in ([0], [1, 2], [3, 4]):
    msg.repeated_int32.extend(chunk)
    expected = expected + chunk
    self.assertSequenceEqual(expected, msg.repeated_int32)
def testExtendFloatWithPythonList(self, message_module):
  """Test extending repeated float fields with python lists."""
  msg = message_module.TestAllTypes()
  expected = []
  self.assertSequenceEqual(expected, msg.repeated_float)
  for chunk in ([0.0], [1.0, 2.0], [3.0, 4.0]):
    msg.repeated_float.extend(chunk)
    expected = expected + chunk
    self.assertSequenceEqual(expected, msg.repeated_float)
def testExtendStringWithPythonList(self, message_module):
  """Test extending repeated string fields with python lists."""
  msg = message_module.TestAllTypes()
  expected = []
  self.assertSequenceEqual(expected, msg.repeated_string)
  for chunk in ([''], ['11', '22'], ['33', '44']):
    msg.repeated_string.extend(chunk)
    expected = expected + chunk
    self.assertSequenceEqual(expected, msg.repeated_string)
def testExtendStringWithString(self, message_module):
  """Test extending repeated string fields with characters from a string."""
  msg = message_module.TestAllTypes()
  self.assertSequenceEqual([], msg.repeated_string)

  # A str is itself iterable, so each character becomes one element.
  msg.repeated_string.extend('abc')
  expected_characters = ['a', 'b', 'c']
  self.assertSequenceEqual(expected_characters, msg.repeated_string)
class TestIterable(object):
  """This iterable object mimics the behavior of numpy.array.

  Truth testing fails for length > 1, and returns bool(item[0]) for
  length == 1.  An empty instance is falsy.

  The truth test is exposed as both __nonzero__ (Python 2) and __bool__
  (Python 3).  Without __bool__, Python 3 would fall back to __len__ and
  never raise for multi-element instances.
  """

  def __init__(self, values=None):
    """Wraps `values`; None (or any falsy value) becomes a new empty list."""
    self._list = values or []

  def __nonzero__(self):
    """Returns the truth value, raising ValueError when it is ambiguous."""
    size = len(self._list)
    if size == 0:
      return False
    if size == 1:
      return bool(self._list[0])
    raise ValueError('Truth value is ambiguous.')

  # Python 3 spells the truth-testing hook __bool__.
  __bool__ = __nonzero__

  def __len__(self):
    """Returns the number of wrapped values."""
    return len(self._list)

  def __iter__(self):
    """Returns an iterator over the wrapped values."""
    return iter(self._list)
def testExtendInt32WithIterable(self, message_module):
  """Test extending repeated int32 fields with iterable."""
  msg = message_module.TestAllTypes()
  expected = []
  self.assertSequenceEqual(expected, msg.repeated_int32)
  # The first chunk is empty, so the first extend must be a no-op.
  for chunk in ([], [0], [1, 2], [3, 4]):
    msg.repeated_int32.extend(MessageTest.TestIterable(chunk))
    expected = expected + chunk
    self.assertSequenceEqual(expected, msg.repeated_int32)
def testExtendFloatWithIterable(self, message_module):
  """Test extending repeated float fields with iterable."""
  msg = message_module.TestAllTypes()
  expected = []
  self.assertSequenceEqual(expected, msg.repeated_float)
  # The first chunk is empty, so the first extend must be a no-op.
  for chunk in ([], [0.0], [1.0, 2.0], [3.0, 4.0]):
    msg.repeated_float.extend(MessageTest.TestIterable(chunk))
    expected = expected + chunk
    self.assertSequenceEqual(expected, msg.repeated_float)
def testExtendStringWithIterable(self, message_module):
  """Test extending repeated string fields with iterable."""
  msg = message_module.TestAllTypes()
  expected = []
  self.assertSequenceEqual(expected, msg.repeated_string)
  # The first chunk is empty, so the first extend must be a no-op.
  for chunk in ([], [''], ['1', '2'], ['3', '4']):
    msg.repeated_string.extend(MessageTest.TestIterable(chunk))
    expected = expected + chunk
    self.assertSequenceEqual(expected, msg.repeated_string)
def testPickleRepeatedScalarContainer(self, message_module):
  """Pickling a repeated scalar container raises PickleError.

  The check is skipped (plain return) unless the implementation type is
  'cpp' and its version is not 2.
  """
  # TODO(tibell): The pure-Python implementation support pickling of
  #   scalar containers in *some* cases. For now the cpp2 version
  #   throws an exception to avoid a segfault. Investigate if we
  #   want to support pickling of these fields.
  #
  # For more information see: https://b2.corp.google.com/u/0/issues/18677897
  applies = (api_implementation.Type() == 'cpp' and
             api_implementation.Version() != 2)
  if not applies:
    return

  msg = message_module.TestAllTypes()
  with self.assertRaises(pickle.PickleError):
    pickle.dumps(msg.repeated_int32, pickle.HIGHEST_PROTOCOL)
def testSortEmptyRepeatedCompositeContainer(self, message_module):
  """Exercise a scenario that has led to segfaults in the past.

  Sorting an empty repeated composite container only has to complete.
  """
  msg = message_module.TestAllTypes()
  empty_container = msg.repeated_nested_message
  empty_container.sort()
def testHasFieldOnRepeatedField(self, message_module):
  """Using HasField on a repeated field should raise an exception."""
  msg = message_module.TestAllTypes()
  self.assertRaises(ValueError, msg.HasField, 'repeated_int32')
def testRepeatedScalarFieldPop(self, message_module):
  """pop() on a repeated scalar field: empty, default index, explicit index."""
  msg = message_module.TestAllTypes()
  with self.assertRaises(IndexError):
    msg.repeated_int32.pop()

  msg.repeated_int32.extend(range(5))
  # (expected popped value, positional arguments handed to pop()).
  for expected, pop_args in ((4, ()), (0, (0,)), (2, (1,))):
    self.assertEqual(expected, msg.repeated_int32.pop(*pop_args))
  self.assertEqual([1, 3], msg.repeated_int32)
def testRepeatedCompositeFieldPop(self, message_module):
  """pop() on a repeated message field, including invalid arguments."""
  msg = message_module.TestAllTypes()
  with self.assertRaises(IndexError):
    msg.repeated_nested_message.pop()
  with self.assertRaises(TypeError):
    msg.repeated_nested_message.pop('0')

  for index in range(5):
    element = msg.repeated_nested_message.add()
    element.bb = index

  # (expected bb of the popped element, positional arguments for pop()).
  for expected_bb, pop_args in ((4, ()), (0, (0,)), (2, (1,))):
    popped = msg.repeated_nested_message.pop(*pop_args)
    self.assertEqual(expected_bb, popped.bb)
  remaining = [element.bb for element in msg.repeated_nested_message]
  self.assertEqual([1, 3], remaining)
def testRepeatedCompareWithSelf(self, message_module):
  """Each kind of repeated container compares equal to itself."""
  msg = message_module.TestAllTypes()
  for index in range(5):
    msg.repeated_int32.insert(index, index)
    msg.repeated_nested_message.add().bb = index

  self.assertSequenceEqual(msg.repeated_int32, msg.repeated_int32)
  self.assertEqual(msg.repeated_nested_message,
                   msg.repeated_nested_message)
def testReleasedNestedMessages(self, message_module):
  """A case that led to a segfault when a message detached from its parent
  container has itself a child container.
  """
  # Every step rebinds `m`, so this test keeps a name only for the most
  # recently obtained message.
  # NOTE(review): presumably dropping the names of the parents is what
  # "detaches" the message - confirm before restructuring these lines.
  m = message_module.NestedTestAllTypes()
  m = m.repeated_child.add()
  m = m.child
  m = m.repeated_child.add()
  # Reading a scalar through the innermost message must yield 0.
  self.assertEqual(m.payload.optional_int32, 0)
def testSetRepeatedComposite(self, message_module):
  """Assigning to a repeated field raises AttributeError, empty or not."""
  msg = message_module.TestAllTypes()
  self.assertRaises(AttributeError, setattr, msg, 'repeated_int32', [])
  msg.repeated_int32.append(1)
  self.assertRaises(AttributeError, setattr, msg, 'repeated_int32', [])
def testReturningType(self, message_module):
  """Float, double and bool fields return float/bool even when given ints.

  The types are checked on a fresh message and again after assigning the
  int 1 everywhere and round-tripping through serialization.
  """
  msg = message_module.TestAllTypes()
  singular_types = (
      ('optional_float', float),
      ('optional_double', float),
      ('optional_bool', bool),
  )
  for field_name, expected_type in singular_types:
    self.assertEqual(expected_type, type(getattr(msg, field_name)))

  for field_name, _ in singular_types:
    setattr(msg, field_name, 1)
  for field_name in ('repeated_float', 'repeated_double', 'repeated_bool'):
    getattr(msg, field_name).append(1)
  msg.ParseFromString(msg.SerializeToString())

  self.assertEqual(float, type(msg.optional_float))
  self.assertEqual(float, type(msg.optional_double))
  self.assertEqual('1.0', str(msg.optional_double))
  self.assertEqual(bool, type(msg.optional_bool))
  self.assertEqual(float, type(msg.repeated_float[0]))
  self.assertEqual(float, type(msg.repeated_double[0]))
  self.assertEqual(bool, type(msg.repeated_bool[0]))
  self.assertEqual(True, msg.repeated_bool[0])
# Class to test proto2-only features (required, extensions, etc.)
@testing_refleaks.TestCase
class Proto2Test(unittest.TestCase):
  """Tests for behavior that only proto2 messages have."""

  def testFieldPresence(self):
    """HasField/ClearField track presence of singular proto2 fields."""
    msg = unittest_pb2.TestAllTypes()

    def expect_defaults():
      self.assertEqual(0, msg.optional_int32)
      self.assertEqual(False, msg.optional_bool)
      self.assertEqual(0, msg.optional_nested_message.bb)

    for name in ("optional_int32", "optional_bool",
                 "optional_nested_message"):
      self.assertFalse(msg.HasField(name))
    for name in ("field_doesnt_exist", "repeated_int32",
                 "repeated_nested_message"):
      with self.assertRaises(ValueError):
        msg.HasField(name)
    expect_defaults()

    # Fields are set even when setting the values to default values.
    msg.optional_int32 = 0
    msg.optional_bool = False
    msg.optional_nested_message.bb = 0
    for name in ("optional_int32", "optional_bool",
                 "optional_nested_message"):
      self.assertTrue(msg.HasField(name))

    # Set the fields to non-default values.
    msg.optional_int32 = 5
    msg.optional_bool = True
    msg.optional_nested_message.bb = 15
    for name in (u"optional_int32", "optional_bool",
                 "optional_nested_message"):
      self.assertTrue(msg.HasField(name))

    # Clearing the fields unsets them and resets their value to default.
    for name in ("optional_int32", u"optional_bool",
                 "optional_nested_message"):
      msg.ClearField(name)
    for name in ("optional_int32", "optional_bool",
                 "optional_nested_message"):
      self.assertFalse(msg.HasField(name))
    expect_defaults()

  def testAssignInvalidEnum(self):
    """Assigning an invalid enum number is not allowed in proto2."""
    proto2_msg = unittest_pb2.TestAllTypes()

    # Proto2 can not assign unknown enum.
    self.assertRaises(
        ValueError, setattr, proto2_msg, 'optional_nested_enum', 1234567)
    with self.assertRaises(ValueError):
      proto2_msg.repeated_nested_enum.append(1234567)
    # Assignment is a different code path than append for the C++ impl.
    proto2_msg.repeated_nested_enum.append(2)
    proto2_msg.repeated_nested_enum[0] = 2
    with self.assertRaises(ValueError):
      proto2_msg.repeated_nested_enum[0] = 123456

    # Unknown enum value can be parsed but is ignored.
    proto3_msg = unittest_proto3_arena_pb2.TestAllTypes()
    proto3_msg.optional_nested_enum = 1234567
    proto3_msg.repeated_nested_enum.append(7654321)
    wire_bytes = proto3_msg.SerializeToString()

    reparsed = unittest_pb2.TestAllTypes()
    reparsed.ParseFromString(wire_bytes)
    self.assertFalse(reparsed.HasField('optional_nested_enum'))
    # 1 is the default value for optional_nested_enum.
    self.assertEqual(1, reparsed.optional_nested_enum)
    self.assertEqual(0, len(reparsed.repeated_nested_enum))

    # Going back through the proto3 message yields the original numbers.
    proto3_msg.Clear()
    proto3_msg.ParseFromString(reparsed.SerializeToString())
    self.assertEqual(1234567, proto3_msg.optional_nested_enum)
    self.assertEqual(7654321, proto3_msg.repeated_nested_enum[0])

  def testUnknownEnumMap(self):
    """Storing 123 in unknown_map_field raises ValueError."""
    enum_map = map_proto2_unittest_pb2.TestEnumMap()
    enum_map.known_map_field[123] = 0
    with self.assertRaises(ValueError):
      enum_map.unknown_map_field[1] = 123

  def testExtensionsErrors(self):
    """TestAllTypes instances expose no Extensions attribute."""
    msg = unittest_pb2.TestAllTypes()
    with self.assertRaises(AttributeError):
      getattr(msg, 'Extensions')

  def testMergeFromExtensions(self):
    """MergeFrom updates an extension on a sub-message read beforehand."""
    target = more_extensions_pb2.TopLevelMessage()
    source = more_extensions_pb2.TopLevelMessage()
    int_extension = more_extensions_pb2.optional_int_extension

    # Cpp extension will lazily create a sub message which is immutable.
    self.assertEqual(0, target.submessage.Extensions[int_extension])
    self.assertFalse(target.HasField('submessage'))
    source.submessage.Extensions[int_extension] = 123

    # Make sure cmessage and extensions pointing to a mutable message
    # after merge instead of the lazily created message.
    target.MergeFrom(source)
    self.assertEqual(123, target.submessage.Extensions[int_extension])

  def testGoldenExtensions(self):
    """The golden message parses to, and re-serializes as, all extensions."""
    golden_data = test_util.GoldenFileData('golden_message')
    parsed = unittest_pb2.TestAllExtensions()
    parsed.ParseFromString(golden_data)

    expected = unittest_pb2.TestAllExtensions()
    test_util.SetAllExtensions(expected)
    self.assertEqual(expected, parsed)
    self.assertEqual(golden_data, parsed.SerializeToString())

    duplicate = copy.deepcopy(parsed)
    self.assertEqual(golden_data, duplicate.SerializeToString())

  def testGoldenPackedExtensions(self):
    """Same golden-data check for the packed-extensions message."""
    golden_data = test_util.GoldenFileData('golden_packed_fields_message')
    parsed = unittest_pb2.TestPackedExtensions()
    parsed.ParseFromString(golden_data)

    expected = unittest_pb2.TestPackedExtensions()
    test_util.SetAllPackedExtensions(expected)
    self.assertEqual(expected, parsed)
    # Here it is the freshly populated message that gets serialized.
    self.assertEqual(golden_data, expected.SerializeToString())

    duplicate = copy.deepcopy(parsed)
    self.assertEqual(golden_data, duplicate.SerializeToString())

  def testPickleIncompleteProto(self):
    """An incomplete message can be pickled but still not serialized."""
    original = unittest_pb2.TestRequired(a=1)
    restored = pickle.loads(pickle.dumps(original))
    self.assertEqual(restored, original)
    self.assertEqual(restored.a, 1)
    # This is still an incomplete proto - so serializing should fail
    with self.assertRaises(message.EncodeError):
      restored.SerializeToString()

  # TODO(haberman): this isn't really a proto2-specific test except that this
  # message has a required field in it.  Should probably be factored out so
  # that we can test the other parts with proto3.
  def testParsingMerge(self):
    """Check the merge behavior when a required or optional field appears
    multiple times in the input."""
    parts = [
        unittest_pb2.TestAllTypes(),
        unittest_pb2.TestAllTypes(),
        unittest_pb2.TestAllTypes(),
    ]
    parts[0].optional_int32 = 1
    parts[1].optional_int64 = 2
    parts[2].optional_int32 = 3
    parts[2].optional_string = 'hello'

    expected = unittest_pb2.TestAllTypes()
    expected.optional_int32 = 3
    expected.optional_int64 = 2
    expected.optional_string = 'hello'

    generator = unittest_pb2.TestParsingMerge.RepeatedFieldsGenerator()
    for field_name in ('field1', 'field2', 'field3', 'ext1', 'ext2'):
      getattr(generator, field_name).extend(parts)
    for group_name in ('group1', 'group2'):
      for part in parts:
        getattr(generator, group_name).add().field1.MergeFrom(part)

    wire_bytes = generator.SerializeToString()
    parsed = unittest_pb2.TestParsingMerge()
    parsed.ParseFromString(wire_bytes)

    # Required and optional fields should be merged.
    self.assertEqual(parsed.required_all_types, expected)
    self.assertEqual(parsed.optional_all_types, expected)
    self.assertEqual(parsed.optionalgroup.optional_group_all_types, expected)
    merged_extension = parsed.Extensions[
        unittest_pb2.TestParsingMerge.optional_ext]
    self.assertEqual(merged_extension, expected)

    # Repeated fields should not be merged.
    self.assertEqual(len(parsed.repeated_all_types), 3)
    self.assertEqual(len(parsed.repeatedgroup), 3)
    repeated_extension = parsed.Extensions[
        unittest_pb2.TestParsingMerge.repeated_ext]
    self.assertEqual(len(repeated_extension), 3)

  def testPythonicInit(self):
    """Constructor keyword arguments populate fields; bad ones raise."""
    msg = unittest_pb2.TestAllTypes(
        optional_int32=100,
        optional_fixed32=200,
        optional_float=300.5,
        optional_bytes=b'x',
        optionalgroup={'a': 400},
        optional_nested_message={'bb': 500},
        optional_foreign_message={},
        optional_nested_enum='BAZ',
        repeatedgroup=[{'a': 600},
                       {'a': 700}],
        repeated_nested_enum=['FOO', unittest_pb2.TestAllTypes.BAR],
        default_int32=800,
        oneof_string='y')
    self.assertIsInstance(msg, unittest_pb2.TestAllTypes)

    for field_name, expected in (('optional_int32', 100),
                                 ('optional_fixed32', 200),
                                 ('optional_float', 300.5),
                                 ('optional_bytes', b'x')):
      self.assertEqual(expected, getattr(msg, field_name))
    self.assertEqual(400, msg.optionalgroup.a)
    self.assertIsInstance(msg.optional_nested_message,
                          unittest_pb2.TestAllTypes.NestedMessage)
    self.assertEqual(500, msg.optional_nested_message.bb)
    self.assertTrue(msg.HasField('optional_foreign_message'))
    self.assertEqual(msg.optional_foreign_message,
                     unittest_pb2.ForeignMessage())
    self.assertEqual(unittest_pb2.TestAllTypes.BAZ,
                     msg.optional_nested_enum)

    self.assertEqual(2, len(msg.repeatedgroup))
    for index, expected_a in enumerate((600, 700)):
      self.assertEqual(expected_a, msg.repeatedgroup[index].a)
    self.assertEqual(2, len(msg.repeated_nested_enum))
    self.assertEqual(unittest_pb2.TestAllTypes.FOO,
                     msg.repeated_nested_enum[0])
    self.assertEqual(unittest_pb2.TestAllTypes.BAR,
                     msg.repeated_nested_enum[1])
    self.assertEqual(800, msg.default_int32)
    self.assertEqual('y', msg.oneof_string)

    # Fields that were not passed keep their unset/default state.
    self.assertFalse(msg.HasField('optional_int64'))
    self.assertEqual(0, len(msg.repeated_float))
    self.assertEqual(42, msg.default_int64)

    msg = unittest_pb2.TestAllTypes(optional_nested_enum=u'BAZ')
    self.assertEqual(unittest_pb2.TestAllTypes.BAZ,
                     msg.optional_nested_enum)

    self.assertRaises(
        ValueError, unittest_pb2.TestAllTypes,
        optional_nested_message={'INVALID_NESTED_FIELD': 17})
    self.assertRaises(
        TypeError, unittest_pb2.TestAllTypes,
        optional_nested_message={'bb': 'INVALID_VALUE_TYPE'})
    self.assertRaises(
        ValueError, unittest_pb2.TestAllTypes,
        optional_nested_enum='INVALID_LABEL')
    self.assertRaises(
        ValueError, unittest_pb2.TestAllTypes,
        repeated_nested_enum='FOO')

  def testPythonicInitWithDict(self):
    """Field names given through **kwargs may be str or unicode keys."""
    # Both string/unicode field name keys should work.
    field_values = {
        'optional_int32': 100,
        u'optional_fixed32': 200,
    }
    msg = unittest_pb2.TestAllTypes(**field_values)
    self.assertEqual(100, msg.optional_int32)
    self.assertEqual(200, msg.optional_fixed32)

  def test_documentation(self):
    """pydoc output names the class, a method and a field."""
    # Also used by the interactive help() function.
    doc = pydoc.html.document(unittest_pb2.TestAllTypes, 'message')
    for expected_text in ('class TestAllTypes',
                          'SerializePartialToString',
                          'repeated_float'):
      self.assertIn(expected_text, doc)
    base = unittest_pb2.TestAllTypes.__bases__[0]
    with self.assertRaises(AttributeError):
      getattr(base, '_extensions_by_name')
- # Class to test proto3-only features/behavior (updated field presence & enums)
- @testing_refleaks.TestCase
- class Proto3Test(unittest.TestCase):
- # Utility method for comparing equality with a map.
def assertMapIterEquals(self, map_iter, dict_value):
  """Asserts that `map_iter` yields exactly the items of `dict_value`.

  Args:
    map_iter: iterable of (key, value) pairs.
    dict_value: mapping with the expected content; it is copied first, so
      the caller's object is not modified.
  """
  # Avoid mutating caller's copy.
  remaining = dict(dict_value)
  for key, value in map_iter:
    self.assertEqual(value, remaining[key])
    remaining.pop(key)
  # Anything left over was expected but never produced by the iterator.
  self.assertEqual({}, remaining)
def testFieldPresence(self):
  """Presence in proto3: HasField works for sub-messages, not for scalars."""
  msg = unittest_proto3_arena_pb2.TestAllTypes()

  def expect_defaults():
    self.assertEqual(0, msg.optional_int32)
    self.assertEqual(0, msg.optional_float)
    self.assertEqual('', msg.optional_string)
    self.assertEqual(False, msg.optional_bool)
    self.assertEqual(0, msg.optional_nested_message.bb)

  # We can't test presence of non-repeated, non-submessage fields.
  for name in ('optional_int32', 'optional_float', 'optional_string',
               'optional_bool'):
    self.assertRaises(ValueError, msg.HasField, name)

  # But we can still test presence of submessage fields.
  self.assertFalse(msg.HasField('optional_nested_message'))

  # As with proto2, we can't test presence of fields that don't exist, or
  # repeated fields.
  for name in ('field_doesnt_exist', 'repeated_int32',
               'repeated_nested_message'):
    self.assertRaises(ValueError, msg.HasField, name)

  # Fields should default to their type-specific default.
  expect_defaults()

  # Setting a submessage should still return proper presence information.
  msg.optional_nested_message.bb = 0
  self.assertTrue(msg.HasField('optional_nested_message'))

  # Set the fields to non-default values.
  msg.optional_int32 = 5
  msg.optional_float = 1.1
  msg.optional_string = 'abc'
  msg.optional_bool = True
  msg.optional_nested_message.bb = 15

  # Clearing the fields unsets them and resets their value to default.
  for name in ('optional_int32', 'optional_float', 'optional_string',
               'optional_bool', 'optional_nested_message'):
    msg.ClearField(name)
  expect_defaults()
def testProto3ParserDropDefaultScalar(self):
  """Parses proto2 bytes holding default-valued scalars into a proto3 message.

  The proto2 message lists three set fields; after parsing its serialization
  the proto3 message is expected to list none.
  """
  proto2_msg = unittest_pb2.TestAllTypes()
  proto2_msg.optional_int32 = 0
  proto2_msg.optional_string = ''
  proto2_msg.optional_bytes = b''
  self.assertEqual(len(proto2_msg.ListFields()), 3)

  wire_bytes = proto2_msg.SerializeToString()
  proto3_msg = unittest_proto3_arena_pb2.TestAllTypes()
  proto3_msg.ParseFromString(wire_bytes)
  self.assertEqual(len(proto3_msg.ListFields()), 0)
def testProto3Optional(self):
  """Tracks HasField() for proto3 `optional` fields across set/parse/clear."""
  msg = test_proto3_optional_pb2.TestProto3Optional()
  tracked_fields = ('optional_int32', 'optional_float', 'optional_string',
                    'optional_nested_message')

  def check_presence(target, expected):
    # Asserts that every tracked field (and the nested 'bb') reports
    # `expected` from HasField().
    verify = self.assertTrue if expected else self.assertFalse
    for field_name in tracked_fields:
      verify(target.HasField(field_name))
    verify(target.optional_nested_message.HasField('bb'))

  check_presence(msg, False)

  # Set fields.
  msg.optional_int32 = 1
  msg.optional_float = 1.0
  msg.optional_string = '123'
  msg.optional_nested_message.bb = 1
  check_presence(msg, True)

  # Set to default value does not clear the fields
  msg.optional_int32 = 0
  msg.optional_float = 0.0
  msg.optional_string = ''
  msg.optional_nested_message.bb = 0
  check_presence(msg, True)

  # Test serialize
  reparsed = test_proto3_optional_pb2.TestProto3Optional()
  reparsed.ParseFromString(msg.SerializeToString())
  check_presence(reparsed, True)

  self.assertEqual(msg.WhichOneof('_optional_int32'), 'optional_int32')

  # Clear these fields.
  for field_name in tracked_fields:
    msg.ClearField(field_name)
  check_presence(msg, False)
  self.assertEqual(msg.WhichOneof('_optional_int32'), None)
def testAssignUnknownEnum(self):
  """Assigning an unknown enum value is allowed and preserves the value."""
  unknown_singular = 1234567
  unknown_appended = 22334455
  unknown_assigned = 7654321

  original = unittest_proto3_arena_pb2.TestAllTypes()
  # Proto3 can assign unknown enums.
  original.optional_nested_enum = unknown_singular
  self.assertEqual(unknown_singular, original.optional_nested_enum)
  original.repeated_nested_enum.append(unknown_appended)
  self.assertEqual(unknown_appended, original.repeated_nested_enum[0])
  # Assignment is a different code path than append for the C++ impl.
  original.repeated_nested_enum[0] = unknown_assigned
  self.assertEqual(unknown_assigned, original.repeated_nested_enum[0])

  # The unknown values must also survive a serialize/parse round trip.
  reparsed = unittest_proto3_arena_pb2.TestAllTypes()
  reparsed.ParseFromString(original.SerializeToString())
  self.assertEqual(unknown_singular, reparsed.optional_nested_enum)
  self.assertEqual(unknown_assigned, reparsed.repeated_nested_enum[0])
- # Map isn't really a proto3-only feature. But there is no proto2 equivalent
- # of google/protobuf/map_unittest.proto right now, so it's not easy to
- # test both with the same test like we do for the other proto2/proto3 tests.
- # (google/protobuf/map_proto2_unittest.proto is very different in the set
- # of messages and fields it contains).
def testScalarMapDefaults(self):
  """Reading an unset scalar map key yields the default and inserts the key."""
  msg = map_unittest_pb2.TestMap()

  # (map field name, probe key, expected default value)
  probes = (
      ('map_int32_int32', -123, 0),
      ('map_int64_int64', -2**33, 0),
      ('map_uint32_uint32', 123, 0),
      ('map_uint64_uint64', 2**33, 0),
      ('map_int32_double', 123, 0.0),
      ('map_bool_bool', False, False),
      ('map_string_string', 'abc', ''),
      ('map_int32_bytes', 111, b''),
      ('map_int32_enum', 888, 0),
  )

  # Scalars start out unset.
  for field_name, key, _ in probes:
    self.assertNotIn(key, getattr(msg, field_name))

  # Accessing an unset key returns the default.
  for field_name, key, default in probes:
    self.assertEqual(default, getattr(msg, field_name)[key])
  self.assertIsInstance(msg.map_int32_double[123], float)
  self.assertIsInstance(msg.map_bool_bool[False], bool)

  # It also sets the value in the map
  for field_name, key, _ in probes:
    self.assertIn(key, getattr(msg, field_name))

  self.assertIsInstance(msg.map_string_string['abc'], six.text_type)

  # Accessing an unset key still throws TypeError if the type of the key
  # is incorrect.
  with self.assertRaises(TypeError):
    msg.map_string_string[123]

  with self.assertRaises(TypeError):
    123 in msg.map_string_string
def testMapGet(self):
  """Checks get() on scalar and message maps, with and without defaults."""
  # get() must honour its own default even though indexing the map has
  # defaultdict-like semantics.
  msg = map_unittest_pb2.TestMap()

  # Scalar-valued map.
  self.assertIsNone(msg.map_int32_int32.get(5))
  self.assertEqual(10, msg.map_int32_int32.get(5, 10))
  self.assertEqual(10, msg.map_int32_int32.get(key=5, default=10))
  self.assertIsNone(msg.map_int32_int32.get(5))

  msg.map_int32_int32[5] = 15
  # Read the stored value twice in a row.
  for _ in range(2):
    self.assertEqual(15, msg.map_int32_int32.get(5))
  with self.assertRaises(TypeError):
    msg.map_int32_int32.get('')

  # Message-valued map.
  self.assertIsNone(msg.map_int32_foreign_message.get(5))
  self.assertEqual(10, msg.map_int32_foreign_message.get(5, 10))
  self.assertEqual(10, msg.map_int32_foreign_message.get(key=5, default=10))

  created = msg.map_int32_foreign_message[5]
  self.assertIs(created, msg.map_int32_foreign_message.get(5))
  with self.assertRaises(TypeError):
    msg.map_int32_foreign_message.get('')
def testScalarMap(self):
  """Sets scalar map entries, checks type validation and round-trips them.

  Bad keys and bad values must raise TypeError both on the original message
  and on a copy parsed from its serialization, and every value written to the
  original must be readable from the parsed copy.
  """
  msg = map_unittest_pb2.TestMap()

  self.assertEqual(0, len(msg.map_int32_int32))
  self.assertFalse(5 in msg.map_int32_int32)

  msg.map_int32_int32[-123] = -456
  msg.map_int64_int64[-2**33] = -2**34
  msg.map_uint32_uint32[123] = 456
  msg.map_uint64_uint64[2**33] = 2**34
  msg.map_int32_float[2] = 1.2
  msg.map_int32_double[1] = 3.3
  msg.map_string_string['abc'] = '123'
  msg.map_bool_bool[True] = True
  msg.map_int32_enum[888] = 2
  # Unknown numeric enum is supported in proto3.
  msg.map_int32_enum[123] = 456

  self.assertEqual([], msg.FindInitializationErrors())

  self.assertEqual(1, len(msg.map_string_string))

  # Bad key.
  with self.assertRaises(TypeError):
    msg.map_string_string[123] = '123'

  # Verify that trying to assign a bad key doesn't actually add a member to
  # the map.
  self.assertEqual(1, len(msg.map_string_string))

  # Bad value.
  with self.assertRaises(TypeError):
    msg.map_string_string['123'] = 123

  serialized = msg.SerializeToString()
  msg2 = map_unittest_pb2.TestMap()
  msg2.ParseFromString(serialized)

  # Bad key.
  with self.assertRaises(TypeError):
    msg2.map_string_string[123] = '123'

  # Bad value.
  with self.assertRaises(TypeError):
    msg2.map_string_string['123'] = 123

  self.assertEqual(-456, msg2.map_int32_int32[-123])
  self.assertEqual(-2**34, msg2.map_int64_int64[-2**33])
  self.assertEqual(456, msg2.map_uint32_uint32[123])
  self.assertEqual(2**34, msg2.map_uint64_uint64[2**33])
  self.assertAlmostEqual(1.2, msg.map_int32_float[2])
  self.assertEqual(3.3, msg.map_int32_double[1])
  # The float and double entries were previously only checked on the original
  # message; also check them on the parsed copy so the round trip is covered
  # for these two fields like for all the others.
  self.assertAlmostEqual(1.2, msg2.map_int32_float[2])
  self.assertEqual(3.3, msg2.map_int32_double[1])
  self.assertEqual('123', msg2.map_string_string['abc'])
  self.assertEqual(True, msg2.map_bool_bool[True])
  self.assertEqual(2, msg2.map_int32_enum[888])
  self.assertEqual(456, msg2.map_int32_enum[123])
  self.assertEqual('{-123: -456}',
                   str(msg2.map_int32_int32))
def testMapEntryAlwaysSerialized(self):
  """Map entries with default-valued key and value are still serialized."""
  expected_wire = b'\n\x04\x08\x00\x10\x00r\x04\n\x00\x12\x00'

  message = map_unittest_pb2.TestMap()
  message.map_int32_int32[0] = 0
  message.map_string_string[''] = ''

  self.assertEqual(message.ByteSize(), 12)
  self.assertEqual(expected_wire, message.SerializeToString())
def testStringUnicodeConversionInMap(self):
  """UTF-8 bytes stored as string-map key/value read back as text."""
  message = map_unittest_pb2.TestMap()

  text = u'\u1234'
  encoded = text.encode('utf8')

  # Store the encoded form as both key and value.
  message.map_string_string[encoded] = encoded

  entries = list(message.map_string_string.items())
  key, value = entries[0]

  self.assertEqual(key, text)
  self.assertEqual(value, text)

  self.assertIsInstance(key, six.text_type)
  self.assertIsInstance(value, six.text_type)
def testMessageMap(self):
  """Creates message-map entries, checks validation and round-trips them."""
  original = map_unittest_pb2.TestMap()

  self.assertEqual(0, len(original.map_int32_foreign_message))
  self.assertNotIn(5, original.map_int32_foreign_message)

  # Indexing creates the entry.
  original.map_int32_foreign_message[123]
  # get_or_create() is an alias for getitem.
  original.map_int32_foreign_message.get_or_create(-456)

  self.assertEqual(2, len(original.map_int32_foreign_message))
  self.assertIn(123, original.map_int32_foreign_message)
  self.assertIn(-456, original.map_int32_foreign_message)
  self.assertEqual(2, len(original.map_int32_foreign_message))

  # Bad key.
  with self.assertRaises(TypeError):
    original.map_int32_foreign_message['123']

  # Can't assign directly to submessage.
  with self.assertRaises(ValueError):
    original.map_int32_foreign_message[999] = (
        original.map_int32_foreign_message[123])

  # Verify that trying to assign a bad key doesn't actually add a member to
  # the map.
  self.assertEqual(2, len(original.map_int32_foreign_message))

  reparsed = map_unittest_pb2.TestMap()
  reparsed.ParseFromString(original.SerializeToString())

  self.assertEqual(2, len(reparsed.map_int32_foreign_message))
  self.assertIn(123, reparsed.map_int32_foreign_message)
  self.assertIn(-456, reparsed.map_int32_foreign_message)
  self.assertEqual(2, len(reparsed.map_int32_foreign_message))

  reparsed.map_int32_foreign_message[123].c = 1
  # TODO(jieluo): Fix text format for message map.
  accepted_renderings = ('{-456: , 123: c: 1\n}', '{123: c: 1\n, -456: }')
  self.assertIn(str(reparsed.map_int32_foreign_message), accepted_renderings)
def testNestedMessageMapItemDelete(self):
  """Deletes and re-creates message-map items, then deletes after parsing."""
  original = map_unittest_pb2.TestMap()

  original.map_int32_all_types[1].optional_nested_message.bb = 1
  del original.map_int32_all_types[1]
  original.map_int32_all_types[2].optional_nested_message.bb = 2
  self.assertEqual(1, len(original.map_int32_all_types))

  # Re-create the entry that was just deleted.
  original.map_int32_all_types[1].optional_nested_message.bb = 1
  self.assertEqual(2, len(original.map_int32_all_types))

  reparsed = map_unittest_pb2.TestMap()
  reparsed.ParseFromString(original.SerializeToString())

  # The loop triggers PyErr_Occurred() in c extension.
  for doomed_key in [1, 2]:
    del reparsed.map_int32_all_types[doomed_key]
def testMapByteSize(self):
  """ByteSize() grows by one when a map value changes from 1 to 128."""
  message = map_unittest_pb2.TestMap()

  # Scalar-valued map entry.
  message.map_int32_int32[1] = 1
  size_before = message.ByteSize()
  message.map_int32_int32[1] = 128
  self.assertEqual(message.ByteSize(), size_before + 1)

  # Field inside a message-valued map entry.
  message.map_int32_foreign_message[19].c = 1
  size_before = message.ByteSize()
  message.map_int32_foreign_message[19].c = 128
  self.assertEqual(message.ByteSize(), size_before + 1)
def testMergeFrom(self):
  """Merges one TestMap into another and checks the resulting map entries."""
  source = map_unittest_pb2.TestMap()
  source.map_int32_int32[12] = 34
  source.map_int32_int32[56] = 78
  source.map_int64_int64[22] = 33
  source.map_int32_foreign_message[111].c = 5
  source.map_int32_foreign_message[222].c = 10

  target = map_unittest_pb2.TestMap()
  target.map_int32_int32[12] = 55
  target.map_int64_int64[88] = 99
  target.map_int32_foreign_message[222].c = 15
  target.map_int32_foreign_message[222].d = 20
  old_map_value = target.map_int32_foreign_message[222]

  target.MergeFrom(source)

  # Compare with expected message instead of call
  # target.map_int32_foreign_message[222] to make sure MergeFrom does not
  # sync with repeated field and there is no duplicated keys.
  expected = map_unittest_pb2.TestMap()
  expected.CopyFrom(source)
  expected.map_int64_int64[88] = 99
  self.assertEqual(target, expected)

  self.assertEqual(34, target.map_int32_int32[12])
  self.assertEqual(78, target.map_int32_int32[56])
  self.assertEqual(33, target.map_int64_int64[22])
  self.assertEqual(99, target.map_int64_int64[88])
  self.assertEqual(5, target.map_int32_foreign_message[111].c)
  self.assertEqual(10, target.map_int32_foreign_message[222].c)
  self.assertFalse(target.map_int32_foreign_message[222].HasField('d'))

  if api_implementation.Type() != 'cpp':
    # During the call to MergeFrom(), the C++ implementation will have
    # deallocated the underlying message, but this is very difficult to detect
    # properly. The line below is likely to cause a segmentation fault.
    # With the Python implementation, old_map_value is just 'detached' from
    # the main message. Using it will not crash of course, but since it still
    # have a reference to the parent message I'm sure we can find interesting
    # ways to cause inconsistencies.
    self.assertEqual(15, old_map_value.c)

  # Verify that there is only one entry per key, even though the MergeFrom
  # may have internally created multiple entries for a single key in the
  # list representation.
  seen = {}
  for key in target.map_int32_foreign_message:
    self.assertNotIn(key, seen)
    seen[key] = target.map_int32_foreign_message[key].c
  self.assertEqual({111: 5, 222: 10}, seen)

  # Special case: test that delete of item really removes the item, even if
  # there might have physically been duplicate keys due to the previous merge.
  # This is only a special case for the C++ implementation which stores the
  # map as an array.
  del target.map_int32_int32[12]
  self.assertNotIn(12, target.map_int32_int32)

  del target.map_int32_foreign_message[222]
  self.assertNotIn(222, target.map_int32_foreign_message)
  with self.assertRaises(TypeError):
    del target.map_int32_foreign_message['']
def testMapMergeFrom(self):
  """Exercises MergeFrom() on individual map fields and on whole messages.

  Covers merging scalar and message maps field by field, merging into a
  message whose map container was already accessed, merging maps that belong
  to different message types, overwriting a message value through
  MergeFromString(), and rejecting a plain dict as merge source.
  """
  msg = map_unittest_pb2.TestMap()
  msg.map_int32_int32[12] = 34
  msg.map_int32_int32[56] = 78
  msg.map_int64_int64[22] = 33
  msg.map_int32_foreign_message[111].c = 5
  msg.map_int32_foreign_message[222].c = 10

  msg2 = map_unittest_pb2.TestMap()
  msg2.map_int32_int32[12] = 55
  msg2.map_int64_int64[88] = 99
  msg2.map_int32_foreign_message[222].c = 15
  msg2.map_int32_foreign_message[222].d = 20

  msg2.map_int32_int32.MergeFrom(msg.map_int32_int32)
  self.assertEqual(34, msg2.map_int32_int32[12])
  self.assertEqual(78, msg2.map_int32_int32[56])

  msg2.map_int64_int64.MergeFrom(msg.map_int64_int64)
  self.assertEqual(33, msg2.map_int64_int64[22])
  self.assertEqual(99, msg2.map_int64_int64[88])

  msg2.map_int32_foreign_message.MergeFrom(msg.map_int32_foreign_message)
  # Compare with expected message instead of call
  # msg.map_int32_foreign_message[222] to make sure MergeFrom does not
  # sync with repeated field and no duplicated keys.
  expected_msg = map_unittest_pb2.TestMap()
  expected_msg.CopyFrom(msg)
  expected_msg.map_int64_int64[88] = 99
  self.assertEqual(msg2, expected_msg)

  # Test when cpp extension cache a map.
  m1 = map_unittest_pb2.TestMap()
  m2 = map_unittest_pb2.TestMap()
  # Touch m1's map before the merge so the container has been accessed.
  self.assertEqual(m1.map_int32_foreign_message,
                   m1.map_int32_foreign_message)
  m2.map_int32_foreign_message[123].c = 10
  m1.MergeFrom(m2)
  # The source must be left untouched by the merge ...
  self.assertEqual(10, m2.map_int32_foreign_message[123].c)
  # ... and the merged entry must be visible in the destination. Previously
  # only m2 was checked, so the merge result itself was never verified.
  self.assertEqual(10, m1.map_int32_foreign_message[123].c)

  # Test merge maps within different message types.
  m1 = map_unittest_pb2.TestMap()
  m2 = map_unittest_pb2.TestMessageMap()
  m2.map_int32_message[123].optional_int32 = 10
  m1.map_int32_all_types.MergeFrom(m2.map_int32_message)
  self.assertEqual(10, m1.map_int32_all_types[123].optional_int32)

  # Test overwrite message value map
  msg = map_unittest_pb2.TestMap()
  msg.map_int32_foreign_message[222].c = 123
  msg2 = map_unittest_pb2.TestMap()
  msg2.map_int32_foreign_message[222].d = 20
  msg.MergeFromString(msg2.SerializeToString())
  self.assertEqual(msg.map_int32_foreign_message[222].d, 20)
  self.assertNotEqual(msg.map_int32_foreign_message[222].c, 123)

  # Merge a dict to map field is not accepted
  with self.assertRaises(AttributeError):
    m1.map_int32_all_types.MergeFrom(
        {1: unittest_proto3_arena_pb2.TestAllTypes()})
def testMergeFromBadType(self):
  """MergeFrom() with a non-message argument raises a descriptive TypeError."""
  msg = map_unittest_pb2.TestMap()
  # six.assertRaisesRegex dispatches to the spelling available on the running
  # interpreter; the assertRaisesRegexp alias used before is deprecated on
  # Python 3 and no longer exists in Python 3.12.
  with six.assertRaisesRegex(
      self,
      TypeError,
      r'Parameter to MergeFrom\(\) must be instance of same class: expected '
      r'.*TestMap got int\.'):
    msg.MergeFrom(1)
def testCopyFromBadType(self):
  """CopyFrom() with a non-message argument raises a descriptive TypeError."""
  msg = map_unittest_pb2.TestMap()
  # six.assertRaisesRegex dispatches to the spelling available on the running
  # interpreter; the assertRaisesRegexp alias used before is deprecated on
  # Python 3 and no longer exists in Python 3.12.
  # The pattern accepts any "<Name>From()" in the message, not only CopyFrom.
  with six.assertRaisesRegex(
      self,
      TypeError,
      r'Parameter to [A-Za-z]*From\(\) must be instance of same class: '
      r'expected .*TestMap got int\.'):
    msg.CopyFrom(1)
def testIntegerMapWithLongs(self):
  """Integer maps accept long() keys and values and round-trip them."""
  original = map_unittest_pb2.TestMap()
  original.map_int32_int32[long(-123)] = long(-456)
  original.map_int64_int64[long(-2**33)] = long(-2**34)
  original.map_uint32_uint32[long(123)] = long(456)
  original.map_uint64_uint64[long(2**33)] = long(2**34)

  reparsed = map_unittest_pb2.TestMap()
  reparsed.ParseFromString(original.SerializeToString())

  # The parsed copy is read back with plain int literals.
  self.assertEqual(-456, reparsed.map_int32_int32[-123])
  self.assertEqual(-2**34, reparsed.map_int64_int64[-2**33])
  self.assertEqual(456, reparsed.map_uint32_uint32[123])
  self.assertEqual(2**34, reparsed.map_uint64_uint64[2**33])
def testMapAssignmentCausesPresence(self):
  """Scalar map mutations inside a submessage survive serialize/parse."""
  original = map_unittest_pb2.TestMapSubmessage()
  reparsed = map_unittest_pb2.TestMapSubmessage()

  def assert_round_trips():
    # Serialize the current state, parse it back and compare.
    reparsed.ParseFromString(original.SerializeToString())
    self.assertEqual(original, reparsed)

  original.test_map.map_int32_int32[123] = 456
  assert_round_trips()

  # Now test that various mutations of the map properly invalidate the
  # cached size of the submessage.
  original.test_map.map_int32_int32[888] = 999
  assert_round_trips()

  original.test_map.map_int32_int32.clear()
  assert_round_trips()
def testMapAssignmentCausesPresenceForSubmessages(self):
  """Message map mutations inside a submessage survive serialize/parse."""
  original = map_unittest_pb2.TestMapSubmessage()
  reparsed = map_unittest_pb2.TestMapSubmessage()

  def assert_round_trips():
    # Serialize the current state, parse it back and compare.
    reparsed.ParseFromString(original.SerializeToString())
    self.assertEqual(original, reparsed)

  original.test_map.map_int32_foreign_message[123].c = 5
  assert_round_trips()

  # Now test that various mutations of the map properly invalidate the
  # cached size of the submessage.
  original.test_map.map_int32_foreign_message[888].c = 7
  assert_round_trips()

  original.test_map.map_int32_foreign_message[888].MergeFrom(
      original.test_map.map_int32_foreign_message[123])
  assert_round_trips()

  original.test_map.map_int32_foreign_message.clear()
  assert_round_trips()
def testModifyMapWhileIterating(self):
  """Iterators created before a map mutation raise RuntimeError when used."""
  message = map_unittest_pb2.TestMap()

  # Create both iterators first, then mutate the underlying maps.
  stale_iterators = (
      iter(message.map_string_string),
      iter(message.map_int32_foreign_message),
  )

  message.map_string_string['abc'] = '123'
  message.map_int32_foreign_message[5].c = 5

  for stale_iter in stale_iterators:
    with self.assertRaises(RuntimeError):
      for _ in stale_iter:
        pass
def testSubmessageMap(self):
  """Message-map values are stable objects and cannot be assigned directly."""
  original = map_unittest_pb2.TestMap()

  entry = original.map_int32_foreign_message[111]
  # Indexing the same key again returns the very same object.
  self.assertIs(entry, original.map_int32_foreign_message[111])
  self.assertIsInstance(entry, unittest_pb2.ForeignMessage)

  entry.c = 5

  reparsed = map_unittest_pb2.TestMap()
  reparsed.ParseFromString(original.SerializeToString())
  self.assertEqual(5, reparsed.map_int32_foreign_message[111].c)

  # Doesn't allow direct submessage assignment.
  with self.assertRaises(ValueError):
    original.map_int32_foreign_message[88] = unittest_pb2.ForeignMessage()
def testMapIteration(self):
  """items() yields nothing for an empty map and every entry once filled."""
  message = map_unittest_pb2.TestMap()

  for _key, _value in message.map_int32_int32.items():
    # An empty map must not produce any item, so this should not be reached.
    self.assertTrue(False)

  expected_entries = {2: 4, 3: 6, 4: 8}
  message.map_int32_int32[2] = 4
  message.map_int32_int32[3] = 6
  message.map_int32_int32[4] = 8
  self.assertEqual(3, len(message.map_int32_int32))

  self.assertMapIterEquals(message.map_int32_int32.items(), expected_entries)
def testPython2Map(self):
  """Exercises the dict-style map API; does nothing on Python 3 and later."""
  if sys.version_info >= (3,):
    return

  source = map_unittest_pb2.TestMap()
  source.map_int32_int32[2] = 4
  source.map_int32_int32[3] = 6
  source.map_int32_int32[4] = 8
  source.map_int32_int32[5] = 10
  map_int32 = source.map_int32_int32
  self.assertEqual(4, len(map_int32))

  snapshot = map_unittest_pb2.TestMap()
  snapshot.ParseFromString(source.SerializeToString())

  def check_items(seq, iterator):
    # The iterator must yield exactly the elements of seq, in order.
    self.assertEqual(next(iterator), seq[0])
    self.assertEqual(list(iterator), seq[1:])

  check_items(map_int32.items(), map_int32.iteritems())
  check_items(map_int32.keys(), map_int32.iterkeys())
  check_items(map_int32.values(), map_int32.itervalues())

  self.assertEqual(6, map_int32.get(3))
  self.assertEqual(None, map_int32.get(999))

  # pop() returns the stored value first, then 0 once the key is gone.
  self.assertEqual(6, map_int32.pop(3))
  self.assertEqual(0, map_int32.pop(3))
  self.assertEqual(3, len(map_int32))

  popped_key, popped_value = map_int32.popitem()
  self.assertEqual(2 * popped_key, popped_value)
  self.assertEqual(2, len(map_int32))

  map_int32.clear()
  self.assertEqual(0, len(map_int32))
  with self.assertRaises(KeyError):
    map_int32.popitem()

  self.assertEqual(0, map_int32.setdefault(2))
  self.assertEqual(1, len(map_int32))

  map_int32.update(snapshot.map_int32_int32)
  self.assertEqual(4, len(map_int32))

  # update() rejects two positional arguments, a non-mapping and keywords.
  with self.assertRaises(TypeError):
    map_int32.update(snapshot.map_int32_int32,
                     snapshot.map_int32_int32)
  with self.assertRaises(TypeError):
    map_int32.update(0)
  with self.assertRaises(TypeError):
    map_int32.update(value=12)
def testMapItems(self):
  """Two successive items() calls on an unchanged map compare equal."""
  # Map items used to have strange behaviors when use c extension. Because
  # [] may reorder the map and invalidate any existing iterators.
  # TODO(jieluo): Check if [] reordering the map is a bug or intended
  # behavior.
  message = map_unittest_pb2.TestMap()
  for key in ('local_init_op', 'trainable_variables', 'variables',
              'init_op', 'summaries'):
    message.map_string_string[key] = ''

  first_view = message.map_string_string.items()
  second_view = message.map_string_string.items()
  self.assertEqual(first_view, second_view)
def testMapDeterministicSerialization(self):
  """SerializeToString(deterministic=True) matches the golden byte string."""
  golden_data = (b'r\x0c\n\x07init_op\x12\x01d'
                 b'r\n\n\x05item1\x12\x01e'
                 b'r\n\n\x05item2\x12\x01f'
                 b'r\n\n\x05item3\x12\x01g'
                 b'r\x0b\n\x05item4\x12\x02QQ'
                 b'r\x12\n\rlocal_init_op\x12\x01a'
                 b'r\x0e\n\tsummaries\x12\x01e'
                 b'r\x18\n\x13trainable_variables\x12\x01b'
                 b'r\x0e\n\tvariables\x12\x01c')

  # Entries are inserted in this (non-sorted) order.
  insertion_order = (
      ('local_init_op', 'a'),
      ('trainable_variables', 'b'),
      ('variables', 'c'),
      ('init_op', 'd'),
      ('summaries', 'e'),
      ('item1', 'e'),
      ('item2', 'f'),
      ('item3', 'g'),
      ('item4', 'QQ'),
  )
  message = map_unittest_pb2.TestMap()
  for key, value in insertion_order:
    message.map_string_string[key] = value

  # If deterministic serialization is not working correctly, this will be
  # "flaky" depending on the exact python dict hash seed.
  #
  # Fortunately, there are enough items in this map that it is extremely
  # unlikely to ever hit the "right" in-order combination, so the test
  # itself should fail reliably.
  self.assertEqual(golden_data,
                   message.SerializeToString(deterministic=True))
def testMapIterationClearMessage(self):
  """items() obtained from a map stays usable after `del` of the message."""
  # Iterator needs to work even if message and map are deleted.
  expected_entries = {2: 4, 3: 6, 4: 8}

  msg = map_unittest_pb2.TestMap()
  msg.map_int32_int32[2] = 4
  msg.map_int32_int32[3] = 6
  msg.map_int32_int32[4] = 8

  surviving_items = msg.map_int32_int32.items()
  # Drop the only local reference to the owning message.
  del msg

  self.assertMapIterEquals(surviving_items, expected_entries)
def testMapConstruction(self):
  """Map fields can be initialised from dicts passed to the constructor."""
  # Scalar-valued map from a plain dict.
  scalar_msg = map_unittest_pb2.TestMap(map_int32_int32={1: 2, 3: 4})
  self.assertEqual(2, scalar_msg.map_int32_int32[1])
  self.assertEqual(4, scalar_msg.map_int32_int32[3])

  # Message-valued map from a dict of message instances.
  initial_entries = {3: unittest_pb2.ForeignMessage(c=5)}
  message_msg = map_unittest_pb2.TestMap(
      map_int32_foreign_message=initial_entries)
  self.assertEqual(5, message_msg.map_int32_foreign_message[3].c)
def testMapScalarFieldConstruction(self):
  """A scalar map field of one message can seed another's constructor."""
  donor = map_unittest_pb2.TestMap()
  donor.map_int32_int32[1] = 42

  recipient = map_unittest_pb2.TestMap(map_int32_int32=donor.map_int32_int32)

  self.assertEqual(42, recipient.map_int32_int32[1])
def testMapMessageFieldConstruction(self):
  """A message map field of one message can seed another's constructor."""
  donor = map_unittest_pb2.TestMap()
  donor.map_string_foreign_message['test'].c = 42

  recipient = map_unittest_pb2.TestMap(
      map_string_foreign_message=donor.map_string_foreign_message)

  self.assertEqual(42, recipient.map_string_foreign_message['test'].c)
def testMapFieldRaisesCorrectError(self):
  """Initialising a map field from a non-iterable raises TypeError."""
  # Should raise a TypeError when given a non-iterable.
  self.assertRaises(
      TypeError, map_unittest_pb2.TestMap, map_string_foreign_message=1)
def testMapValidAfterFieldCleared(self):
  """A held scalar map keeps its entries after ClearField on its owner."""
  # Map needs to work even if field is cleared.
  # For the C++ implementation this tests the correctness of
  # MapContainer::Release()
  expected_entries = {2: 4, 3: 6, 4: 8}

  owner = map_unittest_pb2.TestMap()
  held_map = owner.map_int32_int32
  held_map[2] = 4
  held_map[3] = 6
  held_map[4] = 8

  owner.ClearField('map_int32_int32')

  # The owner serializes as empty, while the held map still has its items.
  self.assertEqual(b'', owner.SerializeToString())
  self.assertMapIterEquals(held_map.items(), expected_entries)
def testMessageMapValidAfterFieldCleared(self):
  """A held message map keeps its keys after ClearField on its owner."""
  # Map needs to work even if field is cleared.
  # For the C++ implementation this tests the correctness of
  # MapContainer::Release()
  owner = map_unittest_pb2.TestMap()
  held_map = owner.map_int32_foreign_message
  held_map[2].c = 5

  owner.ClearField('map_int32_foreign_message')

  # The owner serializes as empty, while the held map still knows key 2.
  self.assertEqual(b'', owner.SerializeToString())
  self.assertTrue(2 in held_map.keys())
def testMessageMapItemValidAfterTopMessageCleared(self):
  """A held message-map value stays readable after Clear() on the owner."""
  # Message map item needs to work even if it is cleared.
  # For the C++ implementation this tests the correctness of
  # MapContainer::Release()
  owner = map_unittest_pb2.TestMap()
  owner.map_int32_all_types[2].optional_string = 'bar'

  using_cpp = api_implementation.Type() == 'cpp'
  if using_cpp:
    # Need to keep the map reference because of b/27942626.
    # TODO(jieluo): Remove it.
    unused_map = owner.map_int32_all_types  # pylint: disable=unused-variable
  held_value = owner.map_int32_all_types[2]

  owner.Clear()

  # Reset to trigger sync between repeated field and map in c++.
  owner.map_int32_all_types[3].optional_string = 'foo'
  self.assertEqual(held_value.optional_string, 'bar')
def testMapIterInvalidatedByClearField(self):
  """Iterating a map iterator after ClearField raises RuntimeError."""
  # Map iterator is invalidated when field is cleared.
  # But this case does need to not crash the interpreter.
  # For the C++ implementation this tests the correctness of
  # ScalarMapContainer::Release()
  message = map_unittest_pb2.TestMap()

  # Same scenario for a scalar-valued map, then a message-valued map.
  for field_name in ('map_int32_int32', 'map_int32_foreign_message'):
    stale_iter = iter(getattr(message, field_name))
    message.ClearField(field_name)
    with self.assertRaises(RuntimeError):
      for _ in stale_iter:
        pass
def testMapDelete(self):
  """`del` removes existing map keys and raises KeyError for missing ones."""
  message = map_unittest_pb2.TestMap()

  self.assertEqual(0, len(message.map_int32_int32))
  message.map_int32_int32[4] = 6
  self.assertEqual(1, len(message.map_int32_int32))

  # Deleting a key that was never set fails.
  with self.assertRaises(KeyError):
    del message.map_int32_int32[88]

  # Deleting the key that was set empties the map again.
  del message.map_int32_int32[4]
  self.assertEqual(0, len(message.map_int32_int32))

  # The same failure applies to a message-valued map.
  with self.assertRaises(KeyError):
    del message.map_int32_all_types[32]
def testMapsAreMapping(self):
  """Scalar and message maps are Mapping and MutableMapping instances."""
  message = map_unittest_pb2.TestMap()
  mapping_abcs = (collections_abc.Mapping, collections_abc.MutableMapping)

  for map_field in (message.map_int32_int32,
                    message.map_int32_foreign_message):
    for abc_type in mapping_abcs:
      self.assertIsInstance(map_field, abc_type)
def testMapsCompare(self):
  """A map equals itself and does not equal an unrelated value."""
  message = map_unittest_pb2.TestMap()
  message.map_int32_int32[-123] = -456

  # Self-comparison for a scalar map and for a message map.
  self.assertEqual(message.map_int32_int32, message.map_int32_int32)
  self.assertEqual(message.map_int32_foreign_message,
                   message.map_int32_foreign_message)

  # Comparison against a non-map value.
  self.assertNotEqual(message.map_int32_int32, 0)
def testMapFindInitializationErrorsSmokeTest(self):
  """FindInitializationErrors() is empty for a message with map entries."""
  message = map_unittest_pb2.TestMap()
  message.map_string_string['abc'] = '123'
  message.map_int32_int32[35] = 64
  message.map_string_foreign_message['foo'].c = 5

  errors = message.FindInitializationErrors()
  self.assertEqual(0, len(errors))
- @unittest.skipIf(sys.maxunicode == UCS2_MAXUNICODE, 'Skip for ucs2')
def testStrictUtf8Check(self):
  """Parsing rejects an encoded surrogate but accepts valid non-BMP text."""
  # Test u'\ud801' is rejected at parser in both python2 and python3.
  bad_wire = (b'r\x03\xed\xa0\x81')
  target = unittest_proto3_arena_pb2.TestAllTypes()
  with self.assertRaises(Exception) as context:
    target.MergeFromString(bad_wire)
  # The error text depends on the active implementation.
  error_text = str(context.exception)
  if api_implementation.Type() == 'python':
    self.assertIn('optional_string', error_text)
  else:
    self.assertIn('Error parsing message', error_text)

  # Test optional_string=u'😍' is accepted.
  good_wire = unittest_proto3_arena_pb2.TestAllTypes(
      optional_string=u'😍').SerializeToString()
  reparsed = unittest_proto3_arena_pb2.TestAllTypes()
  reparsed.MergeFromString(good_wire)
  self.assertEqual(reparsed.optional_string, u'😍')

  # A non-surrogate code point is accepted by the constructor as well.
  constructed = unittest_proto3_arena_pb2.TestAllTypes(
      optional_string=u'\ud001')
  self.assertEqual(constructed.optional_string, u'\ud001')
- @unittest.skipIf(six.PY2, 'Surrogates are acceptable in python2')
def testSurrogatesInPython3(self):
  """Surrogate code points are rejected by string setters on Python 3."""
  # Surrogates like U+D83D is an invalid unicode character, it is
  # supported by Python2 only because in some builds, unicode strings
  # use 2-bytes code units. Since Python 3.3, we don't have this problem.
  #
  # Surrogates are utf16 code units, in a unicode string they are invalid
  # characters even when they appear in pairs like u'\ud801\udc01'. Protobuf
  # Python3 reject such cases at setters and parsers. Python2 accept it
  # to keep same features with the language itself. 'Unpaired pairs'
  # like u'\ud801' are rejected at parsers when strict utf8 check is enabled
  # in proto3 to keep same behavior with c extension.
  rejected_values = (
      u'\ud801\udc01',
      b'\xed\xa0\x81',
      u'\ud801',
      u'\ud801\ud801',
  )

  # Surrogates are rejected at setters in Python3.
  for bad_value in rejected_values:
    with self.assertRaises(ValueError):
      unittest_proto3_arena_pb2.TestAllTypes(optional_string=bad_value)
@unittest.skipIf(six.PY3 or sys.maxunicode == UCS2_MAXUNICODE,
                 'Surrogates are rejected at setters in Python3')
def testSurrogatesInPython2(self):
  # The surrogate pair u'\ud801\udc01' is acceptable to the setter in
  # python2.
  pair_msg = unittest_proto3_arena_pb2.TestAllTypes(
      optional_string=u'\ud801\udc01')

  # TODO(jieluo): Change pure python to have same behavior with c extension.
  # Some python2 builds treat u'\ud801\udc01' and u'\U00010401' as equal,
  # others do not, so the expected read-back value depends on the backend.
  if api_implementation.Type() == 'python':
    expected = u'\ud801\udc01'
  else:
    expected = u'\U00010401'
  self.assertEqual(pair_msg.optional_string, expected)

  # After a serialize/parse round trip the value is the single code point.
  reparsed = unittest_proto3_arena_pb2.TestAllTypes()
  reparsed.MergeFromString(pair_msg.SerializeToString())
  self.assertEqual(reparsed.optional_string, u'\U00010401')

  # Python2 does not reject surrogates at setters: building a message from
  # each of these values must simply not raise.
  for accepted in (b'\xed\xa0\x81', u'\ud801', u'\ud801\ud801'):
    unittest_proto3_arena_pb2.TestAllTypes(optional_string=accepted)
@testing_refleaks.TestCase
class ValidTypeNamesTest(unittest.TestCase):
  """Checks that repeated-field container types have importable names."""

  def assertImportFromName(self, msg, base_name):
    """Asserts that type(msg) has an expected name and an importable module.

    Args:
      msg: The container object whose type name is inspected.
      base_name: Middle part of the expected class name; the type name must
        end with 'Repeated<base_name>Container' or
        'Repeated<base_name>FieldContainer'.
    """
    # str(type(msg)) has the form <type 'module.class_name'>; the text
    # between the first pair of single quotes is extracted as the dotted
    # name.
    tp_name = str(type(msg)).split("'")[1]
    valid_names = ('Repeated%sContainer' % base_name,
                   'Repeated%sFieldContainer' % base_name)
    # str.endswith accepts a tuple and succeeds if any suffix matches.
    self.assertTrue(tp_name.endswith(valid_names),
                    '%r does not end with any of %r' % (tp_name, valid_names))

    # Split 'pkg.module.ClassName' into its module path and class name, then
    # make sure the module can actually be imported.
    module_name, _, class_name = tp_name.rpartition('.')
    __import__(module_name, fromlist=[class_name])

  def testTypeNamesCanBeImported(self):
    # If import doesn't work, pickling won't work either.
    pb = unittest_pb2.TestAllTypes()
    self.assertImportFromName(pb.repeated_int32, 'Scalar')
    self.assertImportFromName(pb.repeated_nested_message, 'Composite')
@testing_refleaks.TestCase
class PackedFieldTest(unittest.TestCase):
  """Pins the serialized bytes of packed and unpacked repeated fields."""

  def setMessage(self, message):
    """Appends exactly one element to each repeated field under test."""
    integer_kinds = ('int32', 'int64', 'uint32', 'uint64', 'sint32', 'sint64',
                     'fixed32', 'fixed64', 'sfixed32', 'sfixed64')
    for kind in integer_kinds:
      getattr(message, 'repeated_' + kind).append(1)
    message.repeated_float.append(1.0)
    message.repeated_double.append(1.0)
    message.repeated_bool.append(True)
    message.repeated_nested_enum.append(1)

  def testPackedFields(self):
    # TestPackedTypes: each chunk below is a tag byte, a length byte, and
    # then the element payload.
    message = packed_field_test_pb2.TestPackedTypes()
    self.setMessage(message)
    golden_data = b''.join((
        b'\x0A\x01\x01',
        b'\x12\x01\x01',
        b'\x1A\x01\x01',
        b'\x22\x01\x01',
        b'\x2A\x01\x02',
        b'\x32\x01\x02',
        b'\x3A\x04\x01\x00\x00\x00',
        b'\x42\x08\x01\x00\x00\x00\x00\x00\x00\x00',
        b'\x4A\x04\x01\x00\x00\x00',
        b'\x52\x08\x01\x00\x00\x00\x00\x00\x00\x00',
        b'\x5A\x04\x00\x00\x80\x3f',
        b'\x62\x08\x00\x00\x00\x00\x00\x00\xf0\x3f',
        b'\x6A\x01\x01',
        b'\x72\x01\x01',
    ))
    self.assertEqual(golden_data, message.SerializeToString())

  def testUnpackedFields(self):
    # TestUnpackedTypes: each chunk below is a tag byte followed directly by
    # the element payload (no length byte).
    message = packed_field_test_pb2.TestUnpackedTypes()
    self.setMessage(message)
    golden_data = b''.join((
        b'\x08\x01',
        b'\x10\x01',
        b'\x18\x01',
        b'\x20\x01',
        b'\x28\x02',
        b'\x30\x02',
        b'\x3D\x01\x00\x00\x00',
        b'\x41\x01\x00\x00\x00\x00\x00\x00\x00',
        b'\x4D\x01\x00\x00\x00',
        b'\x51\x01\x00\x00\x00\x00\x00\x00\x00',
        b'\x5D\x00\x00\x80\x3f',
        b'\x61\x00\x00\x00\x00\x00\x00\xf0\x3f',
        b'\x68\x01',
        b'\x70\x01',
    ))
    self.assertEqual(golden_data, message.SerializeToString())
@unittest.skipIf(api_implementation.Type() != 'cpp' or
                 sys.version_info < (2, 7),
                 'explicit tests of the C++ implementation for PY27 and above')
@testing_refleaks.TestCase
class OversizeProtosTest(unittest.TestCase):
  """Exercises the C++ backend's SetAllowOversizeProtos switch."""

  @classmethod
  def setUpClass(cls):
    # At the moment, reference cycles between DescriptorPool and Message
    # classes are not detected and these objects are never freed. To avoid
    # errors with ReferenceLeakChecker, the message class is built only once
    # per test class instead of once per test.
    file_desc = """
      name: "f/f.msg2"
      package: "f"
      message_type {
        name: "msg1"
        field {
          name: "payload"
          number: 1
          label: LABEL_OPTIONAL
          type: TYPE_STRING
        }
      }
      message_type {
        name: "msg2"
        field {
          name: "field"
          number: 1
          label: LABEL_OPTIONAL
          type: TYPE_MESSAGE
          type_name: "msg1"
        }
      }
    """
    file_proto = descriptor_pb2.FileDescriptorProto()
    text_format.Parse(file_desc, file_proto)

    pool = descriptor_pool.DescriptorPool()
    pool.Add(file_proto)

    factory = message_factory.MessageFactory(pool)
    msg2_descriptor = pool.FindMessageTypeByName('f.msg2')
    cls.proto_cls = factory.GetPrototype(msg2_descriptor)

  def setUp(self):
    # Build a message whose string payload is one character more than 64 MiB
    # characters long, and keep its serialized form for the tests.
    payload_length = 1024 * 1024 * 64 + 1
    self.p = self.proto_cls()
    self.p.field.payload = 'c' * payload_length
    self.p_serialized = self.p.SerializeToString()

  def testAssertOversizeProto(self):
    # Imported here because the module is only used by this C++-only test.
    from google.protobuf.pyext import _message
    _message.SetAllowOversizeProtos(False)
    parsed = self.proto_cls()
    # NOTE(review): this only checks the error text *if* parsing fails; the
    # test also passes when no DecodeError is raised. Confirm that this
    # leniency is intended before tightening it to assertRaises.
    try:
      parsed.ParseFromString(self.p_serialized)
    except message.DecodeError as err:
      self.assertEqual(str(err), 'Error parsing message')

  def testSucceedOversizeProto(self):
    # Imported here because the module is only used by this C++-only test.
    from google.protobuf.pyext import _message
    _message.SetAllowOversizeProtos(True)
    parsed = self.proto_cls()
    parsed.ParseFromString(self.p_serialized)
    self.assertEqual(self.p.field.payload, parsed.field.payload)
# Run every test case in this module when the file is executed as a script.
if __name__ == '__main__':
  unittest.main()
|