| // Copyright 2010-2014, Google Inc. |
| // All rights reserved. |
| // |
| // Redistribution and use in source and binary forms, with or without |
| // modification, are permitted provided that the following conditions are |
| // met: |
| // |
| // * Redistributions of source code must retain the above copyright |
| // notice, this list of conditions and the following disclaimer. |
| // * Redistributions in binary form must reproduce the above |
| // copyright notice, this list of conditions and the following disclaimer |
| // in the documentation and/or other materials provided with the |
| // distribution. |
| // * Neither the name of Google Inc. nor the names of its |
| // contributors may be used to endorse or promote products derived from |
| // this software without specific prior written permission. |
| // |
| // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS |
| // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT |
| // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR |
| // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT |
| // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, |
| // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT |
| // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, |
| // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY |
| // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT |
| // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE |
| // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
| |
| #include "converter/nbest_generator.h" |
| |
| #include <algorithm> |
| #include <string> |
| #include <vector> |
| |
| #include "base/logging.h" |
| #include "base/util.h" |
| #include "converter/candidate_filter.h" |
| #include "converter/connector_interface.h" |
| #include "converter/lattice.h" |
| #include "converter/node.h" |
| #include "converter/segmenter_interface.h" |
| #include "converter/segments.h" |
| #include "dictionary/pos_matcher.h" |
| |
| namespace mozc { |
| namespace { |
| |
| const int kFreeListSize = 512; |
| const int kCostDiff = 3453; // log prob of 1/1000 |
| |
| } // namespace |
| |
| using converter::CandidateFilter; |
| |
| struct NBestGenerator::QueueElement { |
| const Node *node; |
| const QueueElement *next; |
| int32 fx; // f(x) = h(x) + g(x): cost function for A* search |
| int32 gx; // g(x) |
| // transition cost part of g(x). |
| // Do not take the transition costs to edge nodes. |
| int32 structure_gx; |
| int32 w_gx; |
| }; |
| |
| const NBestGenerator::QueueElement *NBestGenerator::CreateNewElement( |
| const Node *node, |
| const QueueElement *next, |
| int32 fx, |
| int32 gx, |
| int32 structure_gx, |
| int32 w_gx) { |
| QueueElement *elm = freelist_.Alloc(); |
| DCHECK(elm); |
| elm->node = node; |
| elm->next = next; |
| elm->fx = fx; |
| elm->gx = gx; |
| elm->structure_gx = structure_gx; |
| elm->w_gx = w_gx; |
| return elm; |
| } |
| |
| struct NBestGenerator::QueueElementComparator { |
| bool operator()(const NBestGenerator::QueueElement *q1, |
| const NBestGenerator::QueueElement *q2) const { |
| return (q1->fx > q2->fx); |
| } |
| }; |
| |
| inline void NBestGenerator::Agenda::Push( |
| const NBestGenerator::QueueElement *element) { |
| priority_queue_.push_back(element); |
| push_heap(priority_queue_.begin(), priority_queue_.end(), |
| QueueElementComparator()); |
| } |
| |
| inline void NBestGenerator::Agenda::Pop() { |
| DCHECK(!priority_queue_.empty()); |
| pop_heap(priority_queue_.begin(), priority_queue_.end(), |
| QueueElementComparator()); |
| priority_queue_.pop_back(); |
| } |
| |
| NBestGenerator::NBestGenerator(const SuppressionDictionary *suppression_dic, |
| const SegmenterInterface *segmenter, |
| const ConnectorInterface *connector, |
| const POSMatcher *pos_matcher, |
| const Lattice *lattice, |
| const SuggestionFilter *suggestion_filter) |
| : suppression_dictionary_(suppression_dic), |
| segmenter_(segmenter), connector_(connector), pos_matcher_(pos_matcher), |
| lattice_(lattice), |
| begin_node_(NULL), end_node_(NULL), |
| freelist_(kFreeListSize), |
| filter_(new CandidateFilter( |
| suppression_dic, pos_matcher, suggestion_filter)), |
| viterbi_result_checked_(false), |
| check_mode_(STRICT), |
| boundary_checker_(NULL) { |
| DCHECK(suppression_dictionary_); |
| DCHECK(segmenter); |
| DCHECK(connector); |
| if (lattice_ == NULL || !lattice_->has_lattice()) { |
| LOG(ERROR) << "lattice is not available"; |
| return; |
| } |
| |
| agenda_.Reserve(kFreeListSize); |
| } |
| |
| NBestGenerator::~NBestGenerator() { |
| } |
| |
| void NBestGenerator::Reset(const Node *begin_node, const Node *end_node, |
| const BoundaryCheckMode mode) { |
| agenda_.Clear(); |
| freelist_.Free(); |
| filter_->Reset(); |
| viterbi_result_checked_ = false; |
| check_mode_ = mode; |
| |
| begin_node_ = begin_node; |
| end_node_ = end_node; |
| |
| for (Node *node = lattice_->begin_nodes(end_node_->begin_pos); |
| node != NULL; node = node->bnext) { |
| if (node == end_node_ || |
| (node->lid != end_node_->lid && |
| node->cost - end_node_->cost <= kCostDiff && |
| node->prev != end_node_->prev)) { |
| // Push "EOS" nodes. |
| agenda_.Push(CreateNewElement(node, NULL, node->cost, 0, 0, 0)); |
| } |
| } |
| |
| switch (check_mode_) { |
| case STRICT: |
| boundary_checker_ = &NBestGenerator::CheckStrict; |
| break; |
| case ONLY_MID: |
| boundary_checker_ = &NBestGenerator::CheckOnlyMid; |
| break; |
| case ONLY_EDGE: |
| boundary_checker_ = &NBestGenerator::CheckOnlyEdge; |
| break; |
| default: |
| LOG(ERROR) << "Invalid check mode"; |
| break; |
| } |
| } |
| |
| void NBestGenerator::MakeCandidate(Segment::Candidate *candidate, |
| int32 cost, int32 structure_cost, |
| int32 wcost, |
| const vector<const Node *> &nodes) const { |
| CHECK(!nodes.empty()); |
| |
| candidate->Init(); |
| candidate->lid = nodes.front()->lid; |
| candidate->rid = nodes.back()->rid; |
| candidate->cost = cost; |
| candidate->structure_cost = structure_cost; |
| candidate->wcost = wcost; |
| |
| bool is_functional = false; |
| for (size_t i = 0; i < nodes.size(); ++i) { |
| const Node *node = nodes[i]; |
| DCHECK(node != NULL); |
| if (!is_functional && !pos_matcher_->IsFunctional(node->lid)) { |
| candidate->content_value += node->value; |
| candidate->content_key += node->key; |
| } else { |
| is_functional = true; |
| } |
| candidate->key += node->key; |
| candidate->value += node->value; |
| |
| if (node->constrained_prev != NULL || |
| (node->next != NULL && node->next->constrained_prev == node)) { |
| // If result has constrained_node, set CONTEXT_SENSITIVE. |
| // If a node has constrained node, the node is generated by |
| // a) compound node and resegmented via personal name resegmentation |
| // b) compound-based reranking. |
| candidate->attributes |= Segment::Candidate::CONTEXT_SENSITIVE; |
| } |
| if (node->attributes & Node::SPELLING_CORRECTION) { |
| candidate->attributes |= Segment::Candidate::SPELLING_CORRECTION; |
| } |
| if (node->attributes & Node::NO_VARIANTS_EXPANSION) { |
| candidate->attributes |= Segment::Candidate::NO_VARIANTS_EXPANSION; |
| } |
| if (node->attributes & Node::USER_DICTIONARY) { |
| candidate->attributes |= Segment::Candidate::USER_DICTIONARY; |
| } |
| } |
| |
| if (candidate->content_value.empty() || candidate->content_key.empty()) { |
| candidate->content_value = candidate->value; |
| candidate->content_key = candidate->key; |
| } |
| |
| candidate->inner_segment_boundary.clear(); |
| if (check_mode_ == ONLY_EDGE) { |
| // For realtime conversion. Set inner segment boundary for user history |
| // prediction from realtime conversion result. |
| size_t key_len = nodes[0]->key.size(), value_len = nodes[0]->value.size(); |
| size_t content_key_len = key_len, content_value_len = value_len; |
| bool is_content_boundary = false; |
| if (pos_matcher_->IsFunctional(nodes[0]->rid)) { |
| is_content_boundary = true; |
| content_key_len = 0; |
| content_value_len = 0; |
| } |
| for (size_t i = 1; i < nodes.size(); ++i) { |
| const Node *lnode = nodes[i - 1]; |
| const Node *rnode = nodes[i]; |
| const bool kMultipleSegments = false; |
| if (segmenter_->IsBoundary(lnode, rnode, kMultipleSegments)) { |
| candidate->PushBackInnerSegmentBoundary( |
| key_len, value_len, content_key_len, content_value_len); |
| key_len = 0; |
| value_len = 0; |
| content_key_len = 0; |
| content_value_len = 0; |
| is_content_boundary = false; |
| } |
| key_len += rnode->key.size(); |
| value_len += rnode->value.size(); |
| if (is_content_boundary) { |
| continue; |
| } |
| // Set boundary only after content nouns or pronouns. For example, |
| // "走った" is formed as |
| // "走っ" (content word) + "た" (functional). |
| // Since the content word is incomplete, we don't want to learn "走っ". |
| if ((pos_matcher_->IsContentNoun(lnode->rid) || |
| pos_matcher_->IsPronoun(lnode->rid)) && |
| pos_matcher_->IsFunctional(rnode->lid)) { |
| is_content_boundary = true; |
| } else { |
| content_key_len += rnode->key.size(); |
| content_value_len += rnode->value.size(); |
| } |
| } |
| candidate->PushBackInnerSegmentBoundary( |
| key_len, value_len, content_key_len, content_value_len); |
| } |
| } |
| |
| bool NBestGenerator::Next(const string &original_key, |
| Segment::Candidate *candidate, |
| Segments::RequestType request_type) { |
| DCHECK(begin_node_); |
| DCHECK(end_node_); |
| |
| DCHECK(candidate); |
| if (lattice_ == NULL || !lattice_->has_lattice()) { |
| LOG(ERROR) << "Must create lattice in advance"; |
| return false; |
| } |
| |
| // |cost| and |structure_cost| are calculated as follows: |
| // |
| // Example: |
| // |left_node| => |node1| => |node2| => |node3| => |right_node|. |
| // |node1| .. |node2| consists of a candidate. |
| // |
| // cost = (left_node->cost - begin_node_->cost) + |
| // trans(left_node, node1) + node1->wcost + |
| // trans(node1, node2) + node2->wcost + |
| // trans(node2, node3) + node3->wcost + |
| // trans(node3, rigt_node) + |
| // (right_node->cost - end_node_->cost) |
| // structure_cost = trans(node1, node2) + trans(node2, node3); |
| // wcost = node1->wcost + |
| // trans(node1, node2) + node2->wcost + |
| // trans(node2, node3) + node3->wcost |
| // |
| // Here (left_node->cost - begin_node_->cost) and |
| // (right_node->cost - end_node->cost) act as an approximation |
| // of marginalized costs of the candidate |node1| .. |node3|. |
| // "marginalized cost" means that how likely the left_node or right_node |
| // are selected by taking the all paths encoded in the lattice. |
| // These approximated costs are exactly 0 when taking Viterbi-best |
| // path. |
| |
| // Insert Viterbi best result here to make sure that |
| // the top result is Viterbi best result. |
| if (!viterbi_result_checked_) { |
| // Use CandiadteFilter so that filter is initialized with the |
| // Viterbi-best path. |
| switch (InsertTopResult(original_key, candidate, request_type)) { |
| case CandidateFilter::GOOD_CANDIDATE: |
| return true; |
| case CandidateFilter::STOP_ENUMERATION: |
| return false; |
| // Viterbi best result was tried to be inserted but reverted. |
| case CandidateFilter::BAD_CANDIDATE: |
| default: |
| // do nothing |
| break; |
| } |
| } |
| |
| const int KMaxTrial = 500; |
| int num_trials = 0; |
| |
| while (!agenda_.IsEmpty()) { |
| const QueueElement *top = agenda_.Top(); |
| DCHECK(top); |
| agenda_.Pop(); |
| const Node *rnode = top->node; |
| CHECK(rnode); |
| |
| if (num_trials++ > KMaxTrial) { // too many trials |
| VLOG(2) << "too many trials: " << num_trials; |
| return false; |
| } |
| |
| // reached to the goal. |
| if (rnode->end_pos == begin_node_->end_pos) { |
| nodes_.clear(); |
| for (const QueueElement *elm = top->next; |
| elm->next != NULL; elm = elm->next) { |
| nodes_.push_back(elm->node); |
| } |
| CHECK(!nodes_.empty()); |
| |
| MakeCandidate(candidate, top->gx, top->structure_gx, top->w_gx, nodes_); |
| const int filter_result = filter_->FilterCandidate(original_key, |
| candidate, |
| nodes_, |
| request_type); |
| nodes_.clear(); |
| |
| switch (filter_result) { |
| case CandidateFilter::GOOD_CANDIDATE: |
| return true; |
| case CandidateFilter::STOP_ENUMERATION: |
| return false; |
| case CandidateFilter::BAD_CANDIDATE: |
| default: |
| break; |
| // do nothing |
| } |
| } else { |
| const QueueElement *best_left_elm = NULL; |
| const bool is_right_edge = rnode->begin_pos == end_node_->begin_pos; |
| const bool is_left_edge = rnode->begin_pos == begin_node_->end_pos; |
| DCHECK(!(is_right_edge && is_left_edge)); |
| |
| // is_edge is true if current lnode/rnode has same boundary as |
| // begin/end node regardless of its value. |
| const bool is_edge = (is_right_edge || is_left_edge); |
| |
| for (Node *lnode = lattice_->end_nodes(rnode->begin_pos); |
| lnode != NULL; lnode = lnode->enext) { |
| // is_invalid_position is true if the lnode's location is invalid |
| // 1. |<-- begin_node_-->| |
| // |<--lnode-->| <== overlapped. |
| // |
| // 2. |<-- begin_node_-->| |
| // |<--lnode-->| <== exceeds begin_node. |
| // This case can't be happened because the |rnode| is always at just |
| // right of the |lnode|. By avoiding case1, this can't be happen. |
| // 2'. |<-- begin_node_-->| |
| // |<--lnode-->||<--rnode-->| |
| const bool is_valid_position = |
| !((lnode->begin_pos < begin_node_->end_pos && |
| begin_node_->end_pos < lnode->end_pos)); |
| if (!is_valid_position) { |
| continue; |
| } |
| |
| // If left_node is left edge, there is a cost-based constraint. |
| const bool is_valid_cost = |
| (lnode->cost - begin_node_->cost) <= kCostDiff; |
| if (is_left_edge && !is_valid_cost) { |
| continue; |
| } |
| |
| // We can omit the search for the node which has the |
| // same rid with |begin_node_| because: |
| // 1. |begin_node_| is the part of the best route. |
| // 2. The cost diff of 'LEFT_EDGE' is decided only by |
| // transition_cost for lnode. |
| // Actually, checking for each rid once is enough. |
| const bool can_omit_search = |
| lnode->rid == begin_node_->rid && lnode != begin_node_; |
| if (is_left_edge && can_omit_search) { |
| continue; |
| } |
| |
| DCHECK(this->boundary_checker_ != NULL); |
| BoundaryCheckResult boundary_result = (this->*boundary_checker_)( |
| lnode, rnode, is_edge); |
| if (boundary_result == INVALID) { |
| continue; |
| } |
| |
| // We can expand candidates from |rnode| to |lnode|. |
| const int transition_cost = GetTransitionCost(lnode, rnode); |
| |
| // How likely the costs get increased after expanding rnode. |
| int cost_diff = 0; |
| int structure_cost_diff = 0; |
| int wcost_diff = 0; |
| |
| if (is_right_edge) { |
| // use |rnode->cost - end_node_->cost| is an approximation |
| // of marginalized word cost. |
| cost_diff = transition_cost + (rnode->cost - end_node_->cost); |
| structure_cost_diff = 0; |
| wcost_diff = 0; |
| } else if (is_left_edge) { |
| // use |lnode->cost - begin_node_->cost| is an approximation |
| // of marginalized word cost. |
| cost_diff = transition_cost + rnode->wcost + |
| (lnode->cost - begin_node_->cost); |
| structure_cost_diff = 0; |
| wcost_diff = rnode->wcost; |
| } else { |
| // use rnode->wcost. |
| cost_diff = transition_cost + rnode->wcost; |
| structure_cost_diff = transition_cost; |
| wcost_diff = transition_cost + rnode->wcost; |
| } |
| |
| if (boundary_result == VALID_WEAK_CONNECTED) { |
| const int kWeakConnectedPenalty = 3453; // log prob of 1/1000 |
| cost_diff += kWeakConnectedPenalty; |
| structure_cost_diff += kWeakConnectedPenalty / 2; |
| wcost_diff += kWeakConnectedPenalty / 2; |
| } |
| |
| const int32 gx = cost_diff + top->gx; |
| // |lnode->cost| is heuristics function of A* search, h(x). |
| // After Viterbi search, we already know an exact value of h(x). |
| const int32 fx = lnode->cost + gx; |
| const int32 structure_gx = structure_cost_diff + top->structure_gx; |
| const int32 w_gx = wcost_diff + top->w_gx; |
| if (is_left_edge) { |
| // We only need to only 1 left node here. |
| // Even if expand all left nodes, all the |value| part should |
| // be identical. Here, we simply use the best left edge node. |
| // This hack reduces the number of redundant calls of pop(). |
| if (best_left_elm == NULL || best_left_elm->fx > fx) { |
| best_left_elm = CreateNewElement( |
| lnode, top, fx, gx, structure_gx, w_gx); |
| } |
| } else { |
| agenda_.Push(CreateNewElement( |
| lnode, top, fx, gx, structure_gx, w_gx)); |
| } |
| } |
| |
| if (best_left_elm != NULL) { |
| agenda_.Push(best_left_elm); |
| } |
| } |
| } |
| |
| return false; |
| } |
| |
| NBestGenerator::BoundaryCheckResult NBestGenerator::CheckOnlyMid( |
| const Node *lnode, const Node *rnode, bool is_edge) const { |
| // Special case, no boundary check |
| if (rnode->node_type == Node::CON_NODE || |
| lnode->node_type == Node::CON_NODE) { |
| return VALID; |
| } |
| |
| // is_boundary is true if there is a grammer-based boundary |
| // between lnode and rnode |
| const bool is_boundary = (lnode->node_type == Node::HIS_NODE || |
| segmenter_->IsBoundary(lnode, rnode, false)); |
| if (!is_edge && is_boundary) { |
| // There is a boundary within the segment. |
| return INVALID; |
| } |
| |
| if (is_edge && !is_boundary) { |
| // Here is not the boundary gramatically, but segmented by |
| // other reason. |
| return VALID_WEAK_CONNECTED; |
| } |
| |
| return VALID; |
| } |
| |
| NBestGenerator::BoundaryCheckResult NBestGenerator::CheckOnlyEdge( |
| const Node *lnode, const Node *rnode, bool is_edge) const { |
| // Special case, no boundary check |
| if (rnode->node_type == Node::CON_NODE || |
| lnode->node_type == Node::CON_NODE) { |
| return VALID; |
| } |
| |
| // is_boundary is true if there is a grammer-based boundary |
| // between lnode and rnode |
| const bool is_boundary = ( |
| lnode->node_type == Node::HIS_NODE || |
| segmenter_->IsBoundary(lnode, rnode, true)); |
| if (is_edge != is_boundary) { |
| // on the edge, have a boudnary. |
| // not on the edge, not the case. |
| return INVALID; |
| } else { |
| return VALID; |
| } |
| } |
| |
| NBestGenerator::BoundaryCheckResult NBestGenerator::CheckStrict( |
| const Node *lnode, const Node *rnode, bool is_edge) const { |
| // Special case, no boundary check |
| if (rnode->node_type == Node::CON_NODE || |
| lnode->node_type == Node::CON_NODE) { |
| return VALID; |
| } |
| |
| // is_boundary is true if there is a grammer-based boundary |
| // between lnode and rnode |
| const bool is_boundary = ( |
| lnode->node_type == Node::HIS_NODE || |
| segmenter_->IsBoundary(lnode, rnode, false)); |
| |
| if (is_edge != is_boundary) { |
| // on the edge, have a boudnary. |
| // not on the edge, not the case. |
| return INVALID; |
| } else { |
| return VALID; |
| } |
| } |
| |
| int NBestGenerator::InsertTopResult(const string &original_key, |
| Segment::Candidate *candidate, |
| Segments::RequestType request_type) { |
| nodes_.clear(); |
| int total_wcost = 0; |
| for (const Node *node = begin_node_->next; |
| node != end_node_; node = node->next) { |
| nodes_.push_back(node); |
| if (node != begin_node_->next) { |
| total_wcost += node->wcost; |
| } |
| } |
| DCHECK(!nodes_.empty()); |
| |
| const int cost = end_node_->cost - |
| begin_node_->cost - end_node_->wcost; |
| const int structure_cost = end_node_->prev->cost - |
| begin_node_->next->cost - total_wcost; |
| const int wcost = end_node_->prev->cost - |
| begin_node_->next->cost + begin_node_->next->wcost; |
| |
| MakeCandidate(candidate, cost, structure_cost, wcost, nodes_); |
| |
| if (request_type == Segments::SUGGESTION) { |
| candidate->attributes |= Segment::Candidate::REALTIME_CONVERSION; |
| } |
| |
| viterbi_result_checked_ = true; |
| const int result = filter_->FilterCandidate( |
| original_key, candidate, nodes_, request_type); |
| nodes_.clear(); |
| return result; |
| } |
| |
| int NBestGenerator::GetTransitionCost(const Node *lnode, |
| const Node *rnode) const { |
| const int kInvalidPenaltyCost = 100000; |
| if (rnode->constrained_prev != NULL && lnode != rnode->constrained_prev) { |
| return kInvalidPenaltyCost; |
| } |
| return connector_->GetTransitionCost(lnode->rid, rnode->lid); |
| } |
| |
| } // namespace mozc |