歡迎來到Linux教程網
Linux教程網
Linux教程網
Linux教程網
Linux教程網 >> Linux基礎 >> 關於Linux >> 使用自己的Python函數處理Protobuf中的字符串編碼

使用自己的Python函數處理Protobuf中的字符串編碼

日期:2017/3/1 11:42:47   编辑:關於Linux

我目前所在的項目是一個老項目,裡面的字符串編碼有點亂,數據庫中有些是GB2312,有些是UTF8;代碼中有些是GBK,有些是UTF8,代碼中轉來轉去,經常是不太清楚當前這個字符串是什麼編碼,由於是老項目,也沒去修改。最近合服腳本由項目上進行維護了,我拿到腳本看了看是Python寫的,我之前也沒學習過Python,只有現學現用。

數據庫中使用了Protobuf,這裡面也有字符串,編碼也是有GBK,也有UTF8編碼的,而且是交叉使用,有過合服經驗的同學應該知道,這裡會涉及一些修改,比如名字沖突需要改名。Protobuf中的名字修改就需要先解析出來修改了再序列化回去。這個時候問題來了,Protobuf默認是使用的UTF8編碼進行解析(Decode)與序列化的(Encode),可以參見:google.protobuf.internal中的decoder.py中的函數:

def StringDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a string field."""

  local_DecodeVarint = _DecodeVarint
  local_unicode = unicode

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_LENGTH_DELIMITED)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        (size, pos) = local_DecodeVarint(buffer, pos)
        new_pos = pos + size
        if new_pos > end:
          raise _DecodeError('Truncated string.')
        value.append(local_unicode(buffer[pos:new_pos], 'utf-8'))
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed.  Return.
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      (size, pos) = local_DecodeVarint(buffer, pos)
      new_pos = pos + size
      if new_pos > end:
        raise _DecodeError('Truncated string.')
      field_dict[key] = local_unicode(buffer[pos:new_pos], 'utf-8')
      return new_pos
    return DecodeField

以及encoder.py中的函數

def StringEncoder(field_number, is_repeated, is_packed):
  """Returns an encoder for a string field."""

  tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
  local_EncodeVarint = _EncodeVarint
  local_len = len
  assert not is_packed
  if is_repeated:
    def EncodeRepeatedField(write, value):
      for element in value:
        encoded = element.encode('utf-8')
        write(tag)
        local_EncodeVarint(write, local_len(encoded))
        write(encoded)
    return EncodeRepeatedField
  else:
    def EncodeField(write, value):
      encoded = value.encode('utf-8')
      write(tag)
      local_EncodeVarint(write, local_len(encoded))
      return write(encoded)
    return EncodeField

如果Protobuf中的字符串編碼為非UTF8編碼,則在解析(Decode)的過程中會出現異常(有點奇怪的是我同事的電腦上沒出現異常):

'utf8' codec can't decode byte……

我們有沒有一個方法在不改變Protobuf原來的代碼的情況下使用自己的函數來進行解析呢,這是我首先想到的,由於沒學習過Python,惡補了一下Python基礎後,研究發現Protobuf是把Decode的函數入口放在了一個數組中,在引入模塊的時候就會自動初始化這些入口函數,然後保存到各個Protobuf類中,各個PB類都有一個decoders_by_tag字典,這個字典就存放了各種數據類型的解析函數入口地址。

通過上面的代碼可以看出,具體解析函數(DecodeField)是放在一個閉包中的,不能直接修改,所以必須整個(StringDecoder)替換。通過深入研究,終於發現了其設置的入口,在google.protobuf.internal的type_checkers.py中有這樣一段代碼:

# Maps from field types to encoder constructors.
TYPE_TO_ENCODER = {
    _FieldDescriptor.TYPE_DOUBLE: encoder.DoubleEncoder,
    _FieldDescriptor.TYPE_FLOAT: encoder.FloatEncoder,
    _FieldDescriptor.TYPE_INT64: encoder.Int64Encoder,
    _FieldDescriptor.TYPE_UINT64: encoder.UInt64Encoder,
    _FieldDescriptor.TYPE_INT32: encoder.Int32Encoder,
    _FieldDescriptor.TYPE_FIXED64: encoder.Fixed64Encoder,
    _FieldDescriptor.TYPE_FIXED32: encoder.Fixed32Encoder,
    _FieldDescriptor.TYPE_BOOL: encoder.BoolEncoder,
    _FieldDescriptor.TYPE_STRING: encoder.StringEncoder,
    _FieldDescriptor.TYPE_GROUP: encoder.GroupEncoder,
    _FieldDescriptor.TYPE_MESSAGE: encoder.MessageEncoder,
    _FieldDescriptor.TYPE_BYTES: encoder.BytesEncoder,
    _FieldDescriptor.TYPE_UINT32: encoder.UInt32Encoder,
    _FieldDescriptor.TYPE_ENUM: encoder.EnumEncoder,
    _FieldDescriptor.TYPE_SFIXED32: encoder.SFixed32Encoder,
    _FieldDescriptor.TYPE_SFIXED64: encoder.SFixed64Encoder,
    _FieldDescriptor.TYPE_SINT32: encoder.SInt32Encoder,
    _FieldDescriptor.TYPE_SINT64: encoder.SInt64Encoder,
    }


# Maps from field types to sizer constructors.
TYPE_TO_SIZER = {
    _FieldDescriptor.TYPE_DOUBLE: encoder.DoubleSizer,
    _FieldDescriptor.TYPE_FLOAT: encoder.FloatSizer,
    _FieldDescriptor.TYPE_INT64: encoder.Int64Sizer,
    _FieldDescriptor.TYPE_UINT64: encoder.UInt64Sizer,
    _FieldDescriptor.TYPE_INT32: encoder.Int32Sizer,
    _FieldDescriptor.TYPE_FIXED64: encoder.Fixed64Sizer,
    _FieldDescriptor.TYPE_FIXED32: encoder.Fixed32Sizer,
    _FieldDescriptor.TYPE_BOOL: encoder.BoolSizer,
    _FieldDescriptor.TYPE_STRING: encoder.StringSizer,
    _FieldDescriptor.TYPE_GROUP: encoder.GroupSizer,
    _FieldDescriptor.TYPE_MESSAGE: encoder.MessageSizer,
    _FieldDescriptor.TYPE_BYTES: encoder.BytesSizer,
    _FieldDescriptor.TYPE_UINT32: encoder.UInt32Sizer,
    _FieldDescriptor.TYPE_ENUM: encoder.EnumSizer,
    _FieldDescriptor.TYPE_SFIXED32: encoder.SFixed32Sizer,
    _FieldDescriptor.TYPE_SFIXED64: encoder.SFixed64Sizer,
    _FieldDescriptor.TYPE_SINT32: encoder.SInt32Sizer,
    _FieldDescriptor.TYPE_SINT64: encoder.SInt64Sizer,
    }


# Maps from field type to a decoder constructor.
TYPE_TO_DECODER = {
    _FieldDescriptor.TYPE_DOUBLE: decoder.DoubleDecoder,
    _FieldDescriptor.TYPE_FLOAT: decoder.FloatDecoder,
    _FieldDescriptor.TYPE_INT64: decoder.Int64Decoder,
    _FieldDescriptor.TYPE_UINT64: decoder.UInt64Decoder,
    _FieldDescriptor.TYPE_INT32: decoder.Int32Decoder,
    _FieldDescriptor.TYPE_FIXED64: decoder.Fixed64Decoder,
    _FieldDescriptor.TYPE_FIXED32: decoder.Fixed32Decoder,
    _FieldDescriptor.TYPE_BOOL: decoder.BoolDecoder,
    _FieldDescriptor.TYPE_STRING: decoder.StringDecoder,
    _FieldDescriptor.TYPE_GROUP: decoder.GroupDecoder,
    _FieldDescriptor.TYPE_MESSAGE: decoder.MessageDecoder,
    _FieldDescriptor.TYPE_BYTES: decoder.BytesDecoder,
    _FieldDescriptor.TYPE_UINT32: decoder.UInt32Decoder,
    _FieldDescriptor.TYPE_ENUM: decoder.EnumDecoder,
    _FieldDescriptor.TYPE_SFIXED32: decoder.SFixed32Decoder,
    _FieldDescriptor.TYPE_SFIXED64: decoder.SFixed64Decoder,
    _FieldDescriptor.TYPE_SINT32: decoder.SInt32Decoder,
    _FieldDescriptor.TYPE_SINT64: decoder.SInt64Decoder,
    }

第一個是序列化(Encoder)的函數入口,第二個是計算大小的函數入口,第三個就是解析(Decoder)的入口,我們可以看到這裡映射了所有類型的處理函數入口,那我們把這個入口函數替換成我們自己的函數,就可以根據實際需要進行處理了。

這裡我們需要特別注意的是Protobuf中的各個類都是在模塊導入的時候就初始化好了,所以,如果我們要修改入口函數,必須在PB各類引入之前進行修改。為此我寫了一個模塊文件:protobuf_hack.py,這個模塊必須先於PB類import,其內容如下:

from google.protobuf.internal import decoder
from google.protobuf.internal import encoder
from google.protobuf.internal import wire_format
from google.protobuf.internal import type_checkers
from google.protobuf import reflection
from google.protobuf import message

def StringDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a string field."""

  local_DecodeVarint = _DecodeVarint
  local_unicode = unicode

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_LENGTH_DELIMITED)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        (size, pos) = local_DecodeVarint(buffer, pos)
        new_pos = pos + size
        if new_pos > end:
          raise _DecodeError('Truncated string.')
        value.append(local_unicode(buffer[pos:new_pos], 'gbk'))
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed.  Return.
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      (size, pos) = local_DecodeVarint(buffer, pos)
      new_pos = pos + size
      if new_pos > end:
        raise _DecodeError('Truncated string.')
      field_dict[key] = local_unicode(buffer[pos:new_pos], 'gbk')
      return new_pos
    return DecodeField

type_checkers.TYPE_TO_DECODER[type_checkers._FieldDescriptor.TYPE_STRING] = StringDecoder 


這樣,我們可以把所有PB中的字符串解析按GBK編碼解析了。但是項目中的字符串並不是所有的字符串都是GBK編碼的,也有UTF8編碼的,為了支持兩種編碼,我做了一個處理,就是先嘗試使用一種編碼解析,如果出現異常,再使用另一種編碼進行解析,這樣就保證了我們所有的字符串都可以正確解析。理想很豐滿,現實很骨感,解析是正確了,但是如果我們序列化回去在服務器程序中去使用的時候就會出現亂碼,因為原來的GBK或者UTF8統一成UTF8編碼了,當然,我們也可以繼續像Decoder調用自己的函數一樣處理Encoder,但是在Encoder中我們並不知道這個字符串原來在數據庫中是什麼編碼,也沒有PB以及字段信息,無法差別處理。

至此,算是白忙活了,無法滿足需要。

如果我們能夠只修改我們指定的PB類的處理函數就好了,因為我們可以找出哪些PB的字符串是GBK編碼的。再次經過深入研究,總算是做到了。

在這裡有一個函數幫了我大忙,reflection.py中的ParseMessage函數,我們看一下:

def ParseMessage(descriptor, byte_str):
  """Generate a new Message instance from this Descriptor and a byte string.

  Args:
    descriptor: Protobuf Descriptor object
    byte_str: Serialized protocol buffer byte string

  Returns:
    Newly created protobuf Message object.
  """

  class _ResultClass(message.Message):
    __metaclass__ = GeneratedProtocolMessageType
    DESCRIPTOR = descriptor

  new_msg = _ResultClass()
  new_msg.ParseFromString(byte_str)
  return new_msg


這個函數其實就是通過描述符信息(descriptor)來解析二進制串,生成一個新的PB消息實例。這中間的關鍵就是函數中的那個動態生成類實例的代碼,在這裡會走一次PB類的初始化流程,即會初始化我們所需要的Decoder以及Encoder函數映射字典。為了工作需要,我修改一下這個函數:

def ParseMessage(descriptor):
  class _ResultClass(message.Message):
    __metaclass__ = reflection.GeneratedProtocolMessageType
    DESCRIPTOR = descriptor

  new_msg = _ResultClass()
  return new_msg

然後加入我們需要使用自定義函數處理的PB類,注意這裡一定是所需要的最小的PB結構。
def hacker(msg):
    ParseMessage(msg.DESCRIPTOR)
	
def hack_pb():
    #修改默認的字符串處理函數入口為自定義函數
    type_checkers.TYPE_TO_DECODER[type_checkers._FieldDescriptor.TYPE_STRING] = StringDecoder
    type_checkers.TYPE_TO_ENCODER[type_checkers._FieldDescriptor.TYPE_STRING] = StringEncoder
    type_checkers.TYPE_TO_SIZER[type_checkers._FieldDescriptor.TYPE_STRING] = StringSizer

    try:
        # 這裡加入我們需要修改的PB類
        hacker(DbProto.DB_FriendAssetEntry_PB)
    except Exception as e:
        print(e)

    #還原字符串處理函數入口
    type_checkers.TYPE_TO_DECODER[type_checkers._FieldDescriptor.TYPE_STRING] = decoder.StringDecoder
    type_checkers.TYPE_TO_ENCODER[type_checkers._FieldDescriptor.TYPE_STRING] = encoder.StringEncoder
    type_checkers.TYPE_TO_SIZER[type_checkers._FieldDescriptor.TYPE_STRING] = encoder.StringSizer

由於Encode的時候Protobuf是先計算字段的長度,然後再處理的各字段,所以我們還需要把計算大小的函數使用自定義函數,否則再次解析會出問題。

現在基本上滿足了需要,算是大功告成了!

細心的讀者,不知你發現沒,這裡還是有一個問題,目前無法解決的問題,就是如果我們一個最小的PB中如果有兩個字符串字段,采用的不同的編碼怎麼辦?一般情況下,正常的設計者不會這樣做,但是就像我們項目中的編碼混亂一樣,如果一個不小心就搞成不一樣的編碼就悲劇了!如果哪位高手有此解決方案,歡迎分享!!!

把整個文件附上:

from google.protobuf.internal import decoder
from google.protobuf.internal import encoder
from google.protobuf.internal import wire_format
from google.protobuf.internal import type_checkers
from google.protobuf import reflection
from google.protobuf import message

def StringDecoder(field_number, is_repeated, is_packed, key, new_default):
    """Returns a decoder for a string field."""

    local_DecodeVarint = decoder._DecodeVarint
    local_unicode = unicode

    assert not is_packed
    if is_repeated:
        tag_bytes = encoder.TagBytes(field_number,
                                     wire_format.WIRETYPE_LENGTH_DELIMITED)
        tag_len = len(tag_bytes)

        def DecodeRepeatedField(buffer, pos, end, message, field_dict):
            value = field_dict.get(key)
            if value is None:
                value = field_dict.setdefault(key, new_default(message))
            while 1:
                (size, pos) = local_DecodeVarint(buffer, pos)
                new_pos = pos + size
                if new_pos > end:
                    raise decoder._DecodeError('Truncated string.')
                str = '' #這裡先嘗試使用UTF8編碼進行解析,如果出現異常則嘗試使用GBK編碼解析
                try:
                    str = local_unicode(buffer[pos:new_pos], 'utf-8')
                except Exception as e:
                    try:
                        str = local_unicode(buffer[pos:new_pos], 'gbk')
                    except Exception as e1:
                        str = ''

                value.append(str)
                # Predict that the next tag is another copy of the same repeated field.
                pos = new_pos + tag_len
                if buffer[new_pos:pos] != tag_bytes or new_pos == end:
                    # Prediction failed.  Return.
                    return new_pos

        return DecodeRepeatedField
    else:
        def DecodeField(buffer, pos, end, message, field_dict):
            (size, pos) = local_DecodeVarint(buffer, pos)
            new_pos = pos + size
            if new_pos > end:
                raise decoder._DecodeError('Truncated string.')

            str = '' #這裡先嘗試使用UTF8編碼進行解析,如果出現異常則嘗試使用GBK編碼解析
            try:
                str = local_unicode(buffer[pos:new_pos], 'utf-8')
            except Exception as e:
                try:
                    str = local_unicode(buffer[pos:new_pos], 'gbk')
                except Exception as e1:
                    str = ''

            field_dict[key] = str
            return new_pos

        return DecodeField


def StringEncoder(field_number, is_repeated, is_packed):
    """Returns an encoder for a string field."""

    tag = encoder.TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
    local_EncodeVarint = encoder._EncodeVarint
    local_len = len
    assert not is_packed
    if is_repeated:
        def EncodeRepeatedField(write, value):
            for element in value:
                encoded = element.encode('gbk') #序列化的時候就直接使用GBK編碼了
                write(tag)
                local_EncodeVarint(write, local_len(encoded))
                write(encoded)

        return EncodeRepeatedField
    else:
        def EncodeField(write, value):
            encoded = value.encode('gbk') #序列化的時候就直接使用GBK編碼了
            write(tag)
            local_EncodeVarint(write, local_len(encoded))
            return write(encoded)

        return EncodeField

def StringSizer(field_number, is_repeated, is_packed):
    """Returns a sizer for a string field."""

    tag_size = encoder._TagSize(field_number)
    local_VarintSize = encoder._VarintSize
    local_len = len
    assert not is_packed
    if is_repeated:
        def RepeatedFieldSize(value):
            result = tag_size * len(value)
            for element in value:
                l = local_len(element.encode('gbk')) #注意序列化前計算長度時也需要使用與序列化相同的編碼,否則會出錯
                result += local_VarintSize(l) + l
            return result

        return RepeatedFieldSize
    else:
        def FieldSize(value):
            l = local_len(value.encode('gbk')) #注意序列化前計算長度時也需要使用與序列化相同的編碼,否則會出錯
            return tag_size + local_VarintSize(l) + l

        return FieldSize

def ParseMessage(descriptor):
  class _ResultClass(message.Message):
    __metaclass__ = reflection.GeneratedProtocolMessageType
    DESCRIPTOR = descriptor

  new_msg = _ResultClass()
  return new_msg

def hacker(msg):
    ParseMessage(msg.DESCRIPTOR)

def hack_pb():
    # 修改默認的字符串處理函數入口為自定義函數
    type_checkers.TYPE_TO_DECODER[type_checkers._FieldDescriptor.TYPE_STRING] = StringDecoder
    type_checkers.TYPE_TO_ENCODER[type_checkers._FieldDescriptor.TYPE_STRING] = StringEncoder
    type_checkers.TYPE_TO_SIZER[type_checkers._FieldDescriptor.TYPE_STRING] = StringSizer

    try:
        # 這裡加入我們需要修改的PB類,注意這裡需要自行import DbProto模塊
        hacker(DbProto.DB_FriendAssetEntry_PB)
    except Exception as e:
        print(e)

    # 還原字符串處理函數入口
    type_checkers.TYPE_TO_DECODER[type_checkers._FieldDescriptor.TYPE_STRING] = decoder.StringDecoder
    type_checkers.TYPE_TO_ENCODER[type_checkers._FieldDescriptor.TYPE_STRING] = encoder.StringEncoder
    type_checkers.TYPE_TO_SIZER[type_checkers._FieldDescriptor.TYPE_STRING] = encoder.StringSizer

#這裡讓其在引入模塊時自動執行
hack_pb()
Copyright © Linux教程網 All Rights Reserved