| // Copyright 2010-2015, Google Inc. |
| // All rights reserved. |
| // |
| // Redistribution and use in source and binary forms, with or without |
| // modification, are permitted provided that the following conditions are |
| // met: |
| // |
| // * Redistributions of source code must retain the above copyright |
| // notice, this list of conditions and the following disclaimer. |
| // * Redistributions in binary form must reproduce the above |
| // copyright notice, this list of conditions and the following disclaimer |
| // in the documentation and/or other materials provided with the |
| // distribution. |
| // * Neither the name of Google Inc. nor the names of its |
| // contributors may be used to endorse or promote products derived from |
| // this software without specific prior written permission. |
| // |
| // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS |
| // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT |
| // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR |
| // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT |
| // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, |
| // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT |
| // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, |
| // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY |
| // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT |
| // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE |
| // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
| |
| #include <iostream> // NOLINT |
| #include <map> |
| #include <numeric> // accumulate |
| #include <string> |
| #include <vector> |
| |
| #include "base/file_stream.h" |
| #include "base/flags.h" |
| #include "base/logging.h" |
| #include "base/multifile.h" |
| #include "base/port.h" |
| #include "base/util.h" |
| #include "client/client.h" |
| #include "evaluation/scorer.h" |
| #include "session/commands.pb.h" |
| |
| // Test data automatically generated by gen_client_quality_test_data.py |
| // TestCase test_cases[] is defined. |
| #include "client/client_quality_test_data.h" |
| |
| DEFINE_string(server_path, "", "specify server path"); |
| DEFINE_string(log_path, "", "specify log output file path"); |
| DEFINE_int32(max_case_for_source, 500, |
| "specify max test case number for each test sources"); |
| |
| namespace mozc { |
| bool IsValidSourceSentence(const string &str) { |
| // TODO(noriyukit) Treat alphabets by changing to Eisu-mode |
| if (Util::ContainsScriptType(str, Util::ALPHABET)) { |
| LOG(WARNING) << "contains ALPHABET: " << str; |
| return false; |
| } |
| |
| // Source should not contain kanji |
| if (Util::ContainsScriptType(str, Util::KANJI)) { |
| LOG(WARNING) << "contains KANJI: " << str; |
| return false; |
| } |
| |
| // Source should not contain katakana |
| string tmp, tmp2; |
| Util::StringReplace(str, "\xE3\x83\xBC", "", true, &tmp); // "ー" -> "" |
| Util::StringReplace(tmp, "\xE3\x83\xBB", "", true, &tmp2); // "・" -> "" |
| if (Util::ContainsScriptType(tmp2, Util::KATAKANA)) { |
| LOG(WARNING) << "contain KATAKANA: " << str; |
| return false; |
| } |
| return true; |
| } |
| |
| bool GenerateKeySequenceFrom(const string& hiragana_sentence, |
| vector<commands::KeyEvent>* keys) { |
| CHECK(keys); |
| keys->clear(); |
| |
| string tmp, input; |
| Util::HiraganaToRomanji(hiragana_sentence, &tmp); |
| Util::FullWidthToHalfWidth(tmp, &input); |
| |
| for (ConstChar32Iterator iter(input); !iter.Done(); iter.Next()) { |
| const char32 ucs4 = iter.Get(); |
| |
| // TODO(noriyukit) Improve key sequence generation; currently, a few ucs4 |
| // codes, like FF5E and 300E, cannot be handled. |
| commands::KeyEvent key; |
| if (ucs4 >= 0x20 && ucs4 <= 0x7F) { |
| key.set_key_code(static_cast<int>(ucs4)); |
| } else if (ucs4 == 0x3001 || ucs4 == 0xFF64) { |
| key.set_key_code(0x002C); // Full-width comma -> Half-width comma |
| } else if (ucs4 == 0x3002 || ucs4 == 0xFF0E || ucs4 == 0xFF61) { |
| key.set_key_code(0x002E); // Full-width period -> Half-width period |
| } else if (ucs4 == 0x2212 || ucs4 == 0x2015) { |
| key.set_key_code(0x002D); // "−" -> "-" |
| } else if (ucs4 == 0x300C || ucs4 == 0xff62) { |
| key.set_key_code(0x005B); // "「" -> "[" |
| } else if (ucs4 == 0x300D || ucs4 == 0xff63) { |
| key.set_key_code(0x005D); // "」" -> "]" |
| } else if (ucs4 == 0x30FB || ucs4 == 0xFF65) { |
| key.set_key_code(0x002F); // "・" -> "/" "・" -> "/" |
| } else { |
| LOG(WARNING) << "Unexpected character: " << hex << ucs4 |
| << ": in " << input << " (" << hiragana_sentence << ")"; |
| return false; |
| } |
| keys->push_back(key); |
| } |
| |
| // Conversion key |
| { |
| commands::KeyEvent key; |
| key.set_special_key(commands::KeyEvent::SPACE); |
| keys->push_back(key); |
| } |
| return true; |
| } |
| |
| bool GetPreedit(const commands::Output &output, string* str) { |
| CHECK(str); |
| |
| if (!output.has_preedit()) { |
| LOG(WARNING) << "No result"; |
| return false; |
| } |
| |
| str->clear(); |
| for (size_t i = 0; i < output.preedit().segment_size(); ++i) { |
| str->append(output.preedit().segment(i).value()); |
| } |
| |
| return true; |
| } |
| |
| bool CalculateBLEU(client::Client* client, |
| const string& hiragana_sentence, |
| const string& expected_result, double* score) { |
| // Prepare key events |
| vector<commands::KeyEvent> keys; |
| if (!GenerateKeySequenceFrom(hiragana_sentence, &keys)) { |
| LOG(WARNING) << "Failed to generated key events from: " |
| << hiragana_sentence; |
| return false; |
| } |
| |
| // Must send ON first |
| commands::Output output; |
| { |
| commands::KeyEvent key; |
| key.set_special_key(commands::KeyEvent::ON); |
| client->SendKey(key, &output); |
| } |
| |
| // Send keys |
| for (size_t i = 0; i < keys.size(); ++i) { |
| client->SendKey(keys[i], &output); |
| } |
| VLOG(2) << "Server response: " << output.Utf8DebugString(); |
| |
| // Calculate score |
| string expected_normalized; |
| Scorer::NormalizeForEvaluate(expected_result, &expected_normalized); |
| vector<string> goldens; |
| goldens.push_back(expected_normalized); |
| string preedit, preedit_normalized; |
| if (!GetPreedit(output, &preedit) || preedit.empty()) { |
| LOG(WARNING) << "Could not get output"; |
| return false; |
| } |
| Scorer::NormalizeForEvaluate(preedit, &preedit_normalized); |
| |
| *score = Scorer::BLEUScore(goldens, preedit_normalized); |
| |
| VLOG(1) << hiragana_sentence << endl |
| << " score: " << (*score) << endl |
| << " preedit: " << preedit_normalized << endl |
| << "expected: " << expected_normalized; |
| |
| // Revert session to prevent server from learning this conversion |
| commands::SessionCommand command; |
| command.set_type(commands::SessionCommand::REVERT); |
| client->SendCommand(command, &output); |
| |
| return true; |
| } |
| |
| double CalculateMean(const vector<double>& scores) { |
| CHECK(!scores.empty()); |
| const double sum = accumulate(scores.begin(), scores.end(), 0.0); |
| return sum / static_cast<double>(scores.size()); |
| } |
| } // namespace mozc |
| |
| |
| int main(int argc, char* argv[]) { |
| InitGoogle(argv[0], &argc, &argv, true); |
| |
| mozc::client::Client client; |
| if (!FLAGS_server_path.empty()) { |
| client.set_server_program(FLAGS_server_path); |
| } |
| |
| CHECK(client.IsValidRunLevel()) << "IsValidRunLevel failed"; |
| CHECK(client.EnsureSession()) << "EnsureSession failed"; |
| CHECK(client.NoOperation()) << "Server is not respoinding"; |
| |
| map<string, vector<double> > scores; // Results to be averaged |
| |
| for (mozc::TestCase* test_case = mozc::test_cases; test_case->source != NULL; |
| ++test_case) { |
| const string &source = test_case->source; |
| const string &hiragana_sentence = test_case->hiragana_sentence; |
| const string &expected_result = test_case->expected_result; |
| |
| if (scores.find(source) == scores.end()) { |
| scores[source] = vector<double>(); |
| } |
| if (scores[source].size() >= FLAGS_max_case_for_source) { |
| continue; |
| } |
| |
| VLOG(1) << "Processing " << hiragana_sentence; |
| if (!mozc::IsValidSourceSentence(hiragana_sentence)) { |
| LOG(WARNING) << "Invalid test case: " << endl |
| << " source: " << source << endl |
| << " hiragana: " << hiragana_sentence << endl |
| << " expected: " << expected_result; |
| continue; |
| } |
| |
| double score; |
| if (!mozc::CalculateBLEU(&client, hiragana_sentence, |
| expected_result, &score)) { |
| LOG(WARNING) << "Failed to calculate BLEU score: " << endl |
| << " source: " << source << endl |
| << " hiragana: " << hiragana_sentence << endl |
| << " expected: " << expected_result; |
| continue; |
| } |
| scores[source].push_back(score); |
| } |
| |
| ostream *ofs = &cout; |
| if (!FLAGS_log_path.empty()) { |
| ofs = new mozc::OutputFileStream(FLAGS_log_path.c_str()); |
| } |
| |
| // Average the scores |
| for (map<string, vector<double> >::iterator it = scores.begin(); |
| it != scores.end(); ++it) { |
| const double mean = mozc::CalculateMean(it->second); |
| (*ofs) << it->first << " : " << mean << endl; |
| } |
| if (ofs != &cout) { |
| delete ofs; |
| } |
| |
| return 0; |
| } |