Protocol Buffersで大きなモデルを読み込む

最近,モデルの読み書きには Protocol Buffersを使っている.

複雑な構造を持つデータを扱ってそれを読み書きするようなプログラムを書くときに一番ダルいI/Oを書かなくて済むのと,バイナリで保存してくれるので記憶容量も少なくて済むところが良い.

普通に使う限りだと,

void read_model(const std::string& model_file, ::google::protobuf::Message* msg) {
  ifstream is(model_file.c_str(), ios::binary);
  if (!msg->ParseFromIstream(&is)) {
    // エラー処理
  }
}

void write_model(const ::google::protobuf::Message* msg, const std::string& model_file) {
  ofstream os(model_file.c_str(), ios_base::binary | ios_base::trunc);
  if (!msg->SerializeToOstream(&os)) {
    // エラー処理
  }
}

てな感じに書くだけでI/Oが出来てしまうので素晴らしい.

普通に使う分には素晴らしいProtocol Buffersであるが,扱えるファイル(message)のサイズに制限があり,デフォルトでは32MBを超えた時点で警告,64MBを超えた時点でエラーとなってしまう.

わたしが最近遊んでいるような音声認識のモデルだと,楽勝でこのサイズを超えてしまうので工夫が必要になる.

サイズの制限を外す方法は

google/protobuf/io/coded_stream.h

void SetTotalBytesLimit(int total_bytes_limit, int warning_threshold);

の宣言の上辺りに以下のように書いてある.

  // Hint:  If you are reading this because your program is printing a                                                                                                                                    
  //   warning about dangerously large protocol messages, you may be                                                                                                                                      
  //   confused about what to do next.  The best option is to change your                                                                                                                                 
  //   design such that excessively large messages are not necessary.                                                                                                                                     
  //   For example, try to design file formats to consist of many small                                                                                                                                   
  //   messages rather than a single large one.  If this is infeasible,                                                                                                                                   
  //   you will need to increase the limit.  Chances are, though, that                                                                                                                                    
  //   your code never constructs a CodedInputStream on which the limit                                                                                                                                   
  //   can be set.  You probably parse messages by calling things like                                                                                                                                    
  //   Message::ParseFromString().  In this case, you will need to change                                                                                                                                 
  //   your code to instead construct some sort of ZeroCopyInputStream                                                                                                                                    
  //   (e.g. an ArrayInputStream), construct a CodedInputStream around                                                                                                                                    
  //   that, then call Message::ParseFromCodedStream() instead.  Then                                                                                                                                     
  //   you can adjust the limit.  Yes, it's more work, but you're doing                                                                                                                                   
  //   something unusual.      

要は普通じゃない使い方だけど,どうしてもというなら方法はあるということだが,色々調べるのがダルかったのでまとめておきます.

以下の様に書くと上手くいった.

void read_model(const std::string& model_file, ::google::protobuf::Message* msg) {
  namespace gpio = ::google::protobuf::io;
  int fd = open(model_file.c_str(), O_RDONLY);
  gpio::ZeroCopyInputStream* raw_input = new gpio::FileInputStream(fd);
  gpio::CodedInputStream* coded_input = new gpio::CodedInputStream(raw_input);
  coded_input->SetTotalBytesLimit(INT_MAX, INT_MAX);
  if (!msg->ParseFromCodedStream(coded_input)) {
     // エラー処理
  }
  delete coded_input;
  delete raw_input;
  close(fd);
}

書き込みの方は特に変更は必要なかった.

あんまり必要になる人はいないだろうが..