001 package org.maltparser.parser.guide.instance; 002 003 import java.io.BufferedReader; 004 import java.io.BufferedWriter; 005 import java.io.IOException; 006 import java.util.ArrayList; 007 import java.util.Collections; 008 import java.util.HashMap; 009 import java.util.HashSet; 010 import java.util.LinkedList; 011 import java.util.List; 012 import java.util.Map; 013 import java.util.Set; 014 import java.util.SortedMap; 015 import java.util.TreeMap; 016 import java.util.Map.Entry; 017 import java.util.regex.Pattern; 018 019 import org.maltparser.core.config.ConfigurationDir; 020 import org.maltparser.core.exception.MaltChainedException; 021 import org.maltparser.core.feature.FeatureException; 022 import org.maltparser.core.feature.FeatureVector; 023 import org.maltparser.core.feature.function.FeatureFunction; 024 import org.maltparser.core.feature.function.Modifiable; 025 import org.maltparser.core.feature.value.SingleFeatureValue; 026 import org.maltparser.core.syntaxgraph.DependencyStructure; 027 import org.maltparser.parser.guide.ClassifierGuide; 028 import org.maltparser.parser.guide.GuideException; 029 import org.maltparser.parser.guide.Model; 030 import org.maltparser.parser.history.action.SingleDecision; 031 032 /** 033 * This class implements a decision tree model. The class is recursive and an 034 * instance of the class can be a root model or belong to an other decision tree 035 * model. Every node in the decision tree is represented by an instance of the 036 * class. Node can be in one of the three states branch model, leaf model or not 037 * decided. A branch model has several sub decision tree models and a leaf model 038 * owns an atomic model that is used to classify instances. When a decision tree 039 * model is in the not decided state it has both sub decision trees and an 040 * atomic model. It can be in the not decided state during training before it is 041 * tested by cross validation if the sub decision tree models provide better 042 * accuracy than the atomic model. 043 * 044 * 045 * @author Kjell Winblad 046 */ 047 public class DecisionTreeModel implements InstanceModel { 048 049 /* 050 * The leaf nodes needs a int index that is unique among all leaf nodes 051 * because they have an AtomicModel which need such an index. 052 */ 053 private static int leafModelIndexConter = 0; 054 055 private final static int OTHER_BRANCH_ID = 1000000;// Integer.MAX_VALUE; 056 057 // The number of division used when doing cross validation test 058 private int numberOfCrossValidationSplits = 10; 059 /* 060 * Cross validation accuracy is calculated for every node during training 061 * This should be calculated for every node and is set to -1.0 if it isn't 062 * calculated yet 063 */ 064 private final static double CROSS_VALIDATION_ACCURACY_NOT_SET_VALUE = -1.0; 065 private double crossValidationAccuracy = CROSS_VALIDATION_ACCURACY_NOT_SET_VALUE; 066 // The parent model 067 private Model parent = null; 068 // An ordered list of features to divide on 069 private LinkedList<FeatureFunction> divideFeatures = null; 070 /* 071 * The branches of the tree Is set to null if this is a leaf node 072 */ 073 private SortedMap<Integer, DecisionTreeModel> branches = null; 074 075 /* 076 * This model is used if this is a leaf node Is set to null if this is a 077 * branch node 078 */ 079 private AtomicModel leafModel = null; 080 // Number of training instances added 081 private int frequency = 0; 082 /* 083 * min number of instances for a node to existAll sub nodes with less 084 * instances will be concatenated to one sub node 085 */ 086 private int divideThreshold = 0; 087 // The feature vector for this problem 088 private FeatureVector featureVector; 089 090 private FeatureVector subFeatureVector = null; 091 092 // Used to indicate that the modelIndex field is not set 093 private static final int MODEL_INDEX_NOT_SET = Integer.MIN_VALUE; 094 /* 095 * Model index is the identifier used to distinguish this model from other 096 * models at the same level. This should not be used in the root model and 097 * has the value MODEL_INDEX_NOT_SET in it. 098 */ 099 private int modelIndex = MODEL_INDEX_NOT_SET; 100 // Indexes of the column used to divide on 101 private ArrayList<Integer> divideFeatureIndexVector; 102 103 private boolean automaticSplit = false; 104 private boolean treeForceDivide = false; 105 106 /** 107 * Constructs a feature divide model. 108 * 109 * @param featureVector 110 * the feature vector used by the decision tree model 111 * @param parent 112 * the parent guide model. 113 * @throws MaltChainedException 114 */ 115 public DecisionTreeModel(FeatureVector featureVector, Model parent) 116 throws MaltChainedException { 117 118 this.featureVector = featureVector; 119 this.divideFeatures = new LinkedList<FeatureFunction>(); 120 setParent(parent); 121 setFrequency(0); 122 initDecisionTreeParam(); 123 124 if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH) { 125 126 // Prepare for training 127 128 branches = new TreeMap<Integer, DecisionTreeModel>(); 129 leafModel = new AtomicModel(-1, featureVector, this); 130 131 } else if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.CLASSIFY) { 132 load(); 133 } 134 } 135 136 /* 137 * This constructor is used from within objects of the class to create sub decision tree models 138 * 139 * 140 */ 141 private DecisionTreeModel(int modelIndex, FeatureVector featureVector, 142 Model parent, LinkedList<FeatureFunction> divideFeatures, 143 int divideThreshold) throws MaltChainedException { 144 145 this.featureVector = featureVector; 146 147 setParent(parent); 148 setFrequency(0); 149 150 this.modelIndex = modelIndex; 151 this.divideFeatures = divideFeatures; 152 this.divideThreshold = divideThreshold; 153 154 if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH) { 155 156 //Create the divide feature index vector 157 if (divideFeatures.size() > 0) { 158 159 divideFeatureIndexVector = new ArrayList<Integer>(); 160 for (int i = 0; i < featureVector.size(); i++) { 161 if (featureVector.get(i).equals(divideFeatures.get(0))) { 162 divideFeatureIndexVector.add(i); 163 } 164 } 165 166 } 167 leafModelIndexConter++; 168 169 170 // Prepare for training 171 branches = new TreeMap<Integer, DecisionTreeModel>(); 172 leafModel = new AtomicModel(-1, featureVector, this); 173 174 } else if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.CLASSIFY) { 175 load(); 176 } 177 } 178 179 /** 180 * Loads the feature divide model settings .fsm file. 181 * 182 * @throws MaltChainedException 183 */ 184 protected void load() throws MaltChainedException { 185 186 ConfigurationDir configDir = getGuide().getConfiguration() 187 .getConfigurationDir(); 188 189 190 // load the dsm file 191 192 try { 193 194 final BufferedReader in = new BufferedReader( 195 configDir.getInputStreamReaderFromConfigFile(getModelName() 196 + ".dsm")); 197 final Pattern tabPattern = Pattern.compile("\t"); 198 199 boolean first = true; 200 while (true) { 201 String line = in.readLine(); 202 if (line == null) 203 break; 204 String[] cols = tabPattern.split(line); 205 if (cols.length != 2) { 206 throw new GuideException(""); 207 } 208 int code = -1; 209 int freq = 0; 210 try { 211 code = Integer.parseInt(cols[0]); 212 freq = Integer.parseInt(cols[1]); 213 } catch (NumberFormatException e) { 214 throw new GuideException( 215 "Could not convert a string value into an integer value when loading the feature divide model settings (.fsm). ", 216 e); 217 } 218 219 if (code == MODEL_INDEX_NOT_SET) { 220 if (!first) 221 throw new GuideException( 222 "Error in config file '" 223 + getModelName() 224 + ".dsm" 225 + "'. If the index in the .dsm file is MODEL_INDEX_NOT_SET it should be the first."); 226 227 first = false; 228 // It is a leaf node 229 // Create atomic model for the leaf node 230 leafModel = new AtomicModel(-1, featureVector, this); 231 232 // setIsLeafNode(); 233 234 } else { 235 if (first) { 236 // Create the branches holder 237 238 branches = new TreeMap<Integer, DecisionTreeModel>(); 239 240 // setIsBranchNode(); 241 242 first = false; 243 } 244 245 if (branches == null) 246 throw new GuideException( 247 "Error in config file '" 248 + getModelName() 249 + ".dsm" 250 + "'. If MODEL_INDEX_NOT_SET is the first model index in the .dsm file it should be the only."); 251 252 if (code == OTHER_BRANCH_ID) 253 branches.put(code, new DecisionTreeModel(code, 254 featureVector, this, 255 new LinkedList<FeatureFunction>(), 256 divideThreshold)); 257 else 258 branches.put(code, new DecisionTreeModel(code, 259 getSubFeatureVector(), this, 260 createNextLevelDivideFeatures(), 261 divideThreshold)); 262 263 branches.get(code).setFrequency(freq); 264 265 setFrequency(getFrequency() + freq); 266 267 } 268 269 } 270 in.close(); 271 272 } catch (IOException e) { 273 throw new GuideException( 274 "Could not read from the guide model settings file '" 275 + getModelName() + ".dsm" + "', when " 276 + "loading the guide model settings. ", e); 277 } 278 279 } 280 281 private void initDecisionTreeParam() throws MaltChainedException { 282 String treeSplitColumns = getGuide().getConfiguration().getOptionValue( 283 "guide", "tree_split_columns").toString(); 284 String treeSplitStructures = getGuide().getConfiguration() 285 .getOptionValue("guide", "tree_split_structures").toString(); 286 287 automaticSplit = getGuide().getConfiguration() 288 .getOptionValue("guide", "tree_automatic_split_order").toString().equals("yes"); 289 290 treeForceDivide = getGuide().getConfiguration() 291 .getOptionValue("guide", "tree_force_divide").toString().equals("yes"); 292 293 if(automaticSplit){ 294 divideFeatures = new LinkedList<FeatureFunction>(); 295 for(FeatureFunction feature:featureVector){ 296 if(feature.getFeatureValue() instanceof SingleFeatureValue) 297 divideFeatures.add(feature); 298 } 299 300 301 }else{ 302 303 if (treeSplitColumns == null || treeSplitColumns.length() == 0) { 304 throw new GuideException( 305 "The option '--guide-tree_split_columns' cannot be found, when initializing the decision tree model. "); 306 } 307 308 if (treeSplitStructures == null || treeSplitStructures.length() == 0) { 309 throw new GuideException( 310 "The option '--guide-tree_split_structures' cannot be found, when initializing the decision tree model. "); 311 } 312 313 String[] treeSplitColumnsArray = treeSplitColumns.split("@"); 314 String[] treeSplitStructuresArray = treeSplitStructures.split("@"); 315 316 if (treeSplitColumnsArray.length != treeSplitStructuresArray.length) 317 throw new GuideException( 318 "The option '--guide-tree_split_structures' and '--guide-tree_split_columns' must be followed by a ; separated lists of the same length"); 319 320 try { 321 322 for (int n = 0; n < treeSplitColumnsArray.length; n++) { 323 324 final String spec = "InputColumn(" 325 + treeSplitColumnsArray[n].trim() + ", " 326 + treeSplitStructuresArray[n].trim() + ")"; 327 328 divideFeatures.addLast(featureVector.getFeatureModel() 329 .identifyFeature(spec)); 330 } 331 332 } catch (FeatureException e) { 333 throw new GuideException("The data split feature 'InputColumn(" 334 + getGuide().getConfiguration().getOptionValue("guide", 335 "data_split_column").toString() 336 + ", " 337 + getGuide().getConfiguration().getOptionValue("guide", 338 "data_split_structure").toString() 339 + ") cannot be initialized. ", e); 340 } 341 342 for (FeatureFunction divideFeature : divideFeatures) { 343 if (!(divideFeature instanceof Modifiable)) { 344 throw new GuideException("The data split feature 'InputColumn(" 345 + getGuide().getConfiguration().getOptionValue("guide", 346 "data_split_column").toString() 347 + ", " 348 + getGuide().getConfiguration().getOptionValue("guide", 349 "data_split_structure").toString() 350 + ") does not implement Modifiable interface. "); 351 } 352 } 353 354 divideFeatureIndexVector = new ArrayList<Integer>(); 355 for (int i = 0; i < featureVector.size(); i++) { 356 357 if (featureVector.get(i).equals(divideFeatures.get(0))) { 358 359 divideFeatureIndexVector.add(i); 360 } 361 } 362 363 if (divideFeatureIndexVector.size() == 0) { 364 throw new GuideException( 365 "Could not match the given divide features to any of the available features."); 366 } 367 368 369 370 } 371 372 try { 373 374 String treeSplitTreshold = getGuide().getConfiguration() 375 .getOptionValue("guide", "tree_split_threshold").toString(); 376 377 if (treeSplitTreshold != null && treeSplitTreshold.length() > 0) { 378 379 divideThreshold = Integer.parseInt(treeSplitTreshold); 380 381 } else { 382 divideThreshold = 0; 383 } 384 } catch (NumberFormatException e) { 385 throw new GuideException( 386 "The --guide-tree_split_threshold option is not an integer value. ", 387 e); 388 } 389 390 try { 391 392 String treeNumberOfCrossValidationDivisions = getGuide() 393 .getConfiguration().getOptionValue("guide", 394 "tree_number_of_cross_validation_divisions") 395 .toString(); 396 397 if (treeNumberOfCrossValidationDivisions != null 398 && treeNumberOfCrossValidationDivisions.length() > 0) { 399 400 numberOfCrossValidationSplits = Integer 401 .parseInt(treeNumberOfCrossValidationDivisions); 402 403 } else { 404 divideThreshold = 0; 405 } 406 } catch (NumberFormatException e) { 407 throw new GuideException( 408 "The --guide-tree_number_of_cross_validation_divisions option is not an integer value. ", 409 e); 410 } 411 412 } 413 414 @Override 415 public void addInstance(SingleDecision decision) 416 throws MaltChainedException { 417 418 if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.CLASSIFY) { 419 throw new GuideException("Can only add instance during learning. "); 420 } else if (divideFeatures.size() > 0) { 421 //FeatureFunction divideFeature = divideFeatures.getFirst(); 422 423 for (FeatureFunction divideFeature : divideFeatures) { 424 if (!(divideFeature.getFeatureValue() instanceof SingleFeatureValue)) { 425 throw new GuideException( 426 "The divide feature does not have a single value. "); 427 } 428 // Is this necessary? 429 divideFeature.update(); 430 } 431 leafModel.addInstance(decision); 432 433 //Update statistics data 434 updateStatistics(decision); 435 436 437 } else { 438 // Model has already been decided. It is a leaf node 439 if (branches != null) 440 setIsLeafNode(); 441 442 leafModel.addInstance(decision); 443 444 //Update statistics data 445 updateStatistics(decision); 446 447 } 448 449 450 451 452 } 453 454 /* 455 private class StatisticsItem{ 456 457 private int columnValue; 458 459 private int classValue; 460 461 public StatisticsItem(int columnValue, int classValue) { 462 super(); 463 this.columnValue = columnValue; 464 this.classValue = classValue; 465 } 466 467 public int getColumnValue() { 468 return columnValue; 469 } 470 471 public int getClassValue() { 472 return classValue; 473 } 474 475 @Override 476 public int hashCode() { 477 return new Integer(columnValue/2).hashCode() + new Integer(classValue/2).hashCode(); 478 } 479 480 @Override 481 public boolean equals(Object obj) { 482 483 StatisticsItem compItem = (StatisticsItem)obj; 484 485 return compItem.getClassValue()==this.getClassValue() && compItem.getColumnValue()==this.getColumnValue(); 486 } 487 } 488 */ 489 490 /* 491 * Helper method used for automatic division by gain ratio 492 * @param n 493 * @return 494 */ 495 private double log2(double n){ 496 return Math.log(n)/Math.log(2.0); 497 } 498 499 /* 500 * This map contains one item per element in the divideFeatures. Mappings exist from every Feature function 501 * in divideFeatures to a corresponding Statistics Item list that contains statistics for that divide feature. 502 * In all positions in the list are a list of StatisticsItems one for every unique feature class 503 * combination in the column. The statistics item also contain a count of that combination. 504 */ 505 //private HashMap<FeatureFunction, HashMap<StatisticsItem, Integer>> statisticsForDivideFatureMap = null; 506 //The keys are class id's and the value is a count of the number of this 507 private HashMap<Integer,Integer> classIdToCountMap = null; 508 509 private HashMap<FeatureFunction, HashMap<Integer,Integer>> featureIdToCountMap = null; 510 511 //private HashMap<FeatureFunction, HashMap<Integer,Integer>> classIdToCountMap = new HashMap<FeatureFunction, HashMap<Integer,Integer>>(); 512 513 private HashMap<FeatureFunction, HashMap<Integer,HashMap<Integer,Integer>>> featureIdToClassIdToCountMap = null; 514 515 private void updateStatistics(SingleDecision decision) 516 throws MaltChainedException { 517 518 // if(statisticsForDivideFatureMap==null){ 519 // statisticsForDivideFatureMap = new HashMap<FeatureFunction, 520 // HashMap<StatisticsItem, Integer>>(); 521 // 522 // for(FeatureFunction columnsDivideFeature : divideFeatures) 523 // statisticsForDivideFatureMap.put(columnsDivideFeature, new 524 // HashMap<StatisticsItem, Integer>()); 525 // } 526 // 527 // 528 // int instanceClass = decision.getDecisionCode(); 529 // 530 // Integer classCount = classCountStatistics.get(instanceClass); 531 // 532 // if(classCount==null){ 533 // classCount=0; 534 // } 535 // 536 // classCountStatistics.put(instanceClass, classCount+1); 537 // 538 // for(FeatureFunction columnsDivideFeature : featureVector){ 539 // 540 // int featureCode = 541 // ((SingleFeatureValue)columnsDivideFeature.getFeatureValue()).getCode(); 542 // HashMap<StatisticsItem, Integer> statisticsMap = 543 // statisticsForDivideFatureMap.get(columnsDivideFeature); 544 // if(statisticsMap!=null){ 545 // 546 // StatisticsItem item = new StatisticsItem(featureCode, instanceClass); 547 // 548 // Integer count = statisticsMap.get(item); 549 // 550 // if(count==null){ 551 // //Add the statistic item to the map 552 // count = 0; 553 // } 554 // 555 // statisticsMap.put(item, count + 1); 556 // 557 // } 558 // 559 // } 560 561 // If it is not done initialize the statistics maps 562 if (featureIdToCountMap == null) { 563 564 featureIdToCountMap = new HashMap<FeatureFunction, HashMap<Integer, Integer>>(); 565 566 for (FeatureFunction columnsDivideFeature : divideFeatures) 567 featureIdToCountMap.put(columnsDivideFeature, 568 new HashMap<Integer, Integer>()); 569 570 571 featureIdToClassIdToCountMap = new HashMap<FeatureFunction, HashMap<Integer, HashMap<Integer, Integer>>>(); 572 573 for (FeatureFunction columnsDivideFeature : divideFeatures) 574 featureIdToClassIdToCountMap.put(columnsDivideFeature, 575 new HashMap<Integer, HashMap<Integer, Integer>>()); 576 577 classIdToCountMap = new HashMap<Integer, Integer>(); 578 579 } 580 581 int instanceClass = decision.getDecisionCode(); 582 583 // Increase classCountStatistics 584 585 Integer classCount = classIdToCountMap.get(instanceClass); 586 587 if (classCount == null) { 588 classCount = 0; 589 } 590 591 classIdToCountMap.put(instanceClass, classCount + 1); 592 593 // Increase featureIdToCountMap 594 595 for (FeatureFunction columnsDivideFeature : divideFeatures) { 596 597 int featureCode = ((SingleFeatureValue) columnsDivideFeature 598 .getFeatureValue()).getCode(); 599 600 HashMap<Integer, Integer> statisticsMap = featureIdToCountMap 601 .get(columnsDivideFeature); 602 603 Integer count = statisticsMap.get(featureCode); 604 605 if (count == null) { 606 // Add the statistic item to the map 607 count = 0; 608 } 609 610 statisticsMap.put(featureCode, count + 1); 611 612 } 613 614 // Increase featureIdToClassIdToCountMap 615 616 for (FeatureFunction columnsDivideFeature : divideFeatures) { 617 618 int featureCode = ((SingleFeatureValue) columnsDivideFeature 619 .getFeatureValue()).getCode(); 620 621 HashMap<Integer, HashMap<Integer, Integer>> featureIdToclassIdToCountMapTmp = featureIdToClassIdToCountMap 622 .get(columnsDivideFeature); 623 624 HashMap<Integer, Integer> classIdToCountMapTmp = featureIdToclassIdToCountMapTmp.get(featureCode); 625 626 if (classIdToCountMapTmp == null) { 627 // Add the statistic item to the map 628 classIdToCountMapTmp = new HashMap<Integer, Integer>(); 629 630 featureIdToclassIdToCountMapTmp.put(featureCode, classIdToCountMapTmp); 631 } 632 633 Integer count = classIdToCountMapTmp.get(instanceClass); 634 635 if (count == null) { 636 // Add the statistic item to the map 637 count = 0; 638 } 639 640 classIdToCountMapTmp.put(instanceClass, count + 1); 641 642 } 643 644 } 645 646 @SuppressWarnings("unchecked") 647 private LinkedList<FeatureFunction> createNextLevelDivideFeatures() { 648 649 LinkedList<FeatureFunction> nextLevelDivideFeatures = (LinkedList<FeatureFunction>) divideFeatures 650 .clone(); 651 652 nextLevelDivideFeatures.removeFirst(); 653 654 return nextLevelDivideFeatures; 655 } 656 657 /* 658 * Removes the current divide feature from the feature vector so it is not 659 * present in the sub node 660 */ 661 private FeatureVector getSubFeatureVector() { 662 663 if (subFeatureVector != null) 664 return subFeatureVector; 665 666 FeatureFunction divideFeature = divideFeatures.getFirst(); 667 668 ArrayList<Integer> divideFeatureIndexVector = new ArrayList<Integer>(); 669 for (int i = 0; i < featureVector.size(); i++) { 670 if (featureVector.get(i).equals(divideFeature)) { 671 divideFeatureIndexVector.add(i); 672 } 673 } 674 675 FeatureVector divideFeatureVector = (FeatureVector) featureVector 676 .clone(); 677 678 for (Integer i : divideFeatureIndexVector) { 679 divideFeatureVector.remove(divideFeatureVector.get(i)); 680 } 681 682 subFeatureVector = divideFeatureVector; 683 684 return divideFeatureVector; 685 } 686 687 @Override 688 public FeatureVector extract() throws MaltChainedException { 689 690 return getCurrentAtomicModel().extract(); 691 692 } 693 694 /* 695 * Returns the atomic model that is effected by this parsing step 696 */ 697 private AtomicModel getCurrentAtomicModel() throws MaltChainedException { 698 699 if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH) { 700 throw new GuideException("Can only predict during parsing. "); 701 } 702 703 if (branches == null && leafModel != null) 704 return leafModel; 705 706 FeatureFunction divideFeature = divideFeatures.getFirst(); 707 708 if (!(divideFeature.getFeatureValue() instanceof SingleFeatureValue)) { 709 throw new GuideException( 710 "The divide feature does not have a single value. "); 711 } 712 713 if (branches != null 714 && branches.containsKey(((SingleFeatureValue) divideFeature 715 .getFeatureValue()).getCode())) { 716 return branches.get( 717 ((SingleFeatureValue) divideFeature.getFeatureValue()) 718 .getCode()).getCurrentAtomicModel(); 719 } else if (branches.containsKey(OTHER_BRANCH_ID) 720 && branches.get(OTHER_BRANCH_ID).getFrequency() > 0) { 721 return branches.get(OTHER_BRANCH_ID).getCurrentAtomicModel(); 722 } else { 723 getGuide() 724 .getConfiguration() 725 .getConfigLogger() 726 .info( 727 "Could not predict the next parser decision because there is " 728 + "no divide or master model that covers the divide value '" 729 + ((SingleFeatureValue) divideFeature 730 .getFeatureValue()).getCode() 731 + "', as default" 732 + " class code '1' is used. "); 733 } 734 return null; 735 } 736 737 /** 738 * Increase the frequency by 1 739 */ 740 public void increaseFrequency() { 741 frequency++; 742 } 743 744 public void decreaseFrequency() { 745 frequency--; 746 } 747 748 @Override 749 public boolean predict(SingleDecision decision) throws MaltChainedException { 750 751 if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH) { 752 throw new GuideException("Can only predict during parsing. "); 753 } else if (divideFeatures.size() > 0 754 && !(divideFeatures.getFirst().getFeatureValue() instanceof SingleFeatureValue)) { 755 throw new GuideException( 756 "The divide feature does not have a single value. "); 757 } 758 759 760 if (branches != null 761 && branches.containsKey(((SingleFeatureValue) divideFeatures 762 .getFirst().getFeatureValue()).getCode())) { 763 764 return branches.get( 765 ((SingleFeatureValue) divideFeatures.getFirst() 766 .getFeatureValue()).getCode()).predict(decision); 767 } else if (branches != null && branches.containsKey(OTHER_BRANCH_ID)) { 768 769 return branches.get(OTHER_BRANCH_ID).predict(decision); 770 } else if (leafModel != null) { 771 772 return leafModel.predict(decision); 773 } else { 774 775 getGuide() 776 .getConfiguration() 777 .getConfigLogger() 778 .info( 779 "Could not predict the next parser decision because there is " 780 + "no divide or master model that covers the divide value '" 781 + ((SingleFeatureValue) divideFeatures 782 .getFirst().getFeatureValue()) 783 .getCode() + "', as default" 784 + " class code '1' is used. "); 785 786 decision.addDecision(1); // default prediction 787 // classCodeTable.getEmptyKBestList().addKBestItem(1); 788 } 789 return true; 790 } 791 792 @Override 793 public FeatureVector predictExtract(SingleDecision decision) 794 throws MaltChainedException { 795 return getCurrentAtomicModel().predictExtract(decision); 796 } 797 798 /* 799 * Decides if this is a branch or leaf node by doing cross validation and 800 * returns the cross validation score for this node 801 */ 802 private double decideNodeType() throws MaltChainedException { 803 804 // We don't want to do this twice test 805 if (crossValidationAccuracy != CROSS_VALIDATION_ACCURACY_NOT_SET_VALUE) 806 return crossValidationAccuracy; 807 808 if (modelIndex == MODEL_INDEX_NOT_SET) 809 if (getGuide().getConfiguration().getConfigLogger().isInfoEnabled()) { 810 getGuide().getConfiguration().getConfigLogger().info( 811 "Starting deph first pruning of the decision tree\n"); 812 } 813 814 long start = System.currentTimeMillis(); 815 816 double leafModelCrossValidationAccuracy = 0.0; 817 818 if(treeForceDivide) 819 if (getGuide().getConfiguration().getConfigLogger().isInfoEnabled()) { 820 getGuide().getConfiguration().getConfigLogger().info( 821 "Skipping cross validation of the root node since the flag tree_force_divide is set to yes. " + 822 "The cross validation score for the root node is set to zero.\n"); 823 } 824 825 if(!treeForceDivide) 826 leafModelCrossValidationAccuracy = leafModel.getMethod() 827 .crossValidate(featureVector, numberOfCrossValidationSplits); 828 829 long stop = System.currentTimeMillis(); 830 831 if (getGuide().getConfiguration().getConfigLogger().isInfoEnabled()) { 832 getGuide().getConfiguration().getConfigLogger().info( 833 "Cross Validation Time: " + (stop - start) + " ms" 834 + " for model " + getModelName() + "\n"); 835 } 836 837 if (getGuide().getConfiguration().getConfigLogger().isInfoEnabled()) { 838 getGuide().getConfiguration().getConfigLogger().info( 839 "Cross Validation Accuracy as leaf node = " 840 + leafModelCrossValidationAccuracy + " for model " 841 + getModelName() + "\n"); 842 } 843 844 if (branches == null && leafModel != null) {// If it is already decided 845 // that this is a leaf node 846 847 crossValidationAccuracy = leafModelCrossValidationAccuracy; 848 849 return crossValidationAccuracy; 850 851 } 852 853 int totalFrequency = 0; 854 double totalAccuracyCount = 0.0; 855 // Calculate crossValidationAccuracy for branch nodes 856 for (DecisionTreeModel b : branches.values()) { 857 858 double bAccuracy = b.decideNodeType(); 859 860 totalFrequency = totalFrequency + b.getFrequency(); 861 862 totalAccuracyCount = totalAccuracyCount + bAccuracy 863 * b.getFrequency(); 864 865 } 866 867 double branchModelCrossValidationAccuracy = totalAccuracyCount 868 / totalFrequency; 869 870 if (getGuide().getConfiguration().getConfigLogger().isInfoEnabled()) { 871 getGuide().getConfiguration().getConfigLogger().info( 872 "Total Cross Validation Accuracy for branches = " 873 + branchModelCrossValidationAccuracy 874 + " for model " + getModelName() + "\n"); 875 } 876 877 // Finally decide which model to use 878 if (branchModelCrossValidationAccuracy > leafModelCrossValidationAccuracy) { 879 880 setIsBranchNode(); 881 882 crossValidationAccuracy = branchModelCrossValidationAccuracy; 883 884 return crossValidationAccuracy; 885 886 } else { 887 888 setIsLeafNode(); 889 890 crossValidationAccuracy = leafModelCrossValidationAccuracy; 891 892 return crossValidationAccuracy; 893 894 } 895 896 } 897 898 @Override 899 public void train() throws MaltChainedException { 900 901 // Decide node type 902 // This operation is more expensive than the training itself 903 decideNodeType(); 904 905 // Do the training depending on which type of node this is 906 if (branches == null && leafModel != null) { 907 908 // If it is a leaf node 909 910 leafModel.train(); 911 912 save(); 913 914 leafModel.terminate(); 915 916 } else { 917 // It is a branch node 918 919 for (DecisionTreeModel b : branches.values()) 920 b.train(); 921 922 save(); 923 924 for (DecisionTreeModel b : branches.values()) 925 b.terminate(); 926 927 } 928 terminate(); 929 930 } 931 932 /** 933 * Saves the decision tree model settings .dsm file. 934 * 935 * @throws MaltChainedException 936 */ 937 private void save() throws MaltChainedException { 938 try { 939 940 final BufferedWriter out = new BufferedWriter(getGuide() 941 .getConfiguration().getConfigurationDir() 942 .getOutputStreamWriter(getModelName() + ".dsm")); 943 944 if (branches != null) { 945 for (DecisionTreeModel b : branches.values()) { 946 out.write(b.getModelIndex() + "\t" + b.getFrequency() 947 + "\n"); 948 } 949 } else { 950 out.write(MODEL_INDEX_NOT_SET + "\t" + getFrequency() + "\n"); 951 } 952 953 out.close(); 954 955 } catch (IOException e) { 956 throw new GuideException( 957 "Could not write to the guide model settings file '" 958 + getModelName() + ".dsm" 959 + "' or the name mapping file '" + getModelName() 960 + ".nmf" + "', when " 961 + "saving the guide model settings to files. ", e); 962 } 963 } 964 965 @Override 966 public void finalizeSentence(DependencyStructure dependencyGraph) 967 throws MaltChainedException { 968 969 if (branches != null) { 970 971 for (DecisionTreeModel b : branches.values()) { 972 b.finalizeSentence(dependencyGraph); 973 } 974 975 } else if (leafModel != null) { 976 977 leafModel.finalizeSentence(dependencyGraph); 978 979 } else { 980 981 throw new GuideException( 982 "The feature divide models cannot be found. "); 983 984 } 985 986 } 987 988 @Override 989 public ClassifierGuide getGuide() { 990 return parent.getGuide(); 991 } 992 993 @Override 994 public String getModelName() throws MaltChainedException { 995 try { 996 997 return parent.getModelName() 998 + (modelIndex == MODEL_INDEX_NOT_SET ? "" 999 : ("_" + modelIndex)); 1000 } catch (NullPointerException e) { 1001 throw new GuideException( 1002 "The parent guide model cannot be found. ", e); 1003 } 1004 } 1005 1006 /* 1007 * This is called to define this node as to be in the leaf state. It sets branches to null. 1008 */ 1009 private void setIsLeafNode() throws MaltChainedException { 1010 1011 if (branches == null && leafModel != null) 1012 return; 1013 1014 if (branches != null && leafModel != null) { 1015 1016 for (DecisionTreeModel t : branches.values()) 1017 t.terminate(); 1018 1019 branches = null; 1020 1021 } else 1022 throw new MaltChainedException( 1023 "Can't set a node that have aleready been set to a leaf node."); 1024 1025 } 1026 /* 1027 * This is called to define this node as to be in the branch state. It sets leafModel to null. 1028 */ 1029 private void setIsBranchNode() throws MaltChainedException { 1030 if (branches != null && leafModel != null) { 1031 1032 leafModel.terminate(); 1033 1034 leafModel = null; 1035 1036 } else 1037 throw new MaltChainedException( 1038 "Can't set a node that have aleready been set to a branch node."); 1039 1040 } 1041 1042 1043 @Override 1044 public void noMoreInstances() throws MaltChainedException { 1045 1046 if (leafModel == null) 1047 throw new GuideException( 1048 "The model in tree node is null in a state where it is not allowed"); 1049 1050 leafModel.noMoreInstances(); 1051 1052 if (divideFeatures.size() == 0) 1053 setIsLeafNode(); 1054 1055 if (branches != null) { 1056 1057 if(automaticSplit){ 1058 1059 divideFeatures = createGainRatioSplitList(divideFeatures); 1060 1061 divideFeatureIndexVector = new ArrayList<Integer>(); 1062 for (int i = 0; i < featureVector.size(); i++) { 1063 1064 if (featureVector.get(i).equals(divideFeatures.get(0))) { 1065 1066 divideFeatureIndexVector.add(i); 1067 } 1068 } 1069 1070 if (divideFeatureIndexVector.size() == 0) { 1071 throw new GuideException( 1072 "Could not match the given divide features to any of the available features."); 1073 } 1074 1075 } 1076 1077 FeatureFunction divideFeature = divideFeatures.getFirst(); 1078 1079 divideFeature.updateCardinality(); 1080 1081 leafModel.noMoreInstances(); 1082 1083 Map<Integer, Integer> divideFeatureIdToCountMap = leafModel 1084 .getMethod().createFeatureIdToCountMap( 1085 divideFeatureIndexVector); 1086 1087 int totalInOther = 0; 1088 1089 Set<Integer> featureIdsToCreateSeparateBranchesForSet = new HashSet<Integer>(); 1090 1091 List<Integer> removeFromDivideFeatureIdToCountMap = new LinkedList<Integer>(); 1092 1093 for (Entry<Integer, Integer> entry : divideFeatureIdToCountMap 1094 .entrySet()) 1095 if (entry.getValue() >= divideThreshold) { 1096 featureIdsToCreateSeparateBranchesForSet 1097 .add(entry.getKey()); 1098 } else { 1099 removeFromDivideFeatureIdToCountMap.add(entry.getKey()); 1100 totalInOther = totalInOther + entry.getValue(); 1101 } 1102 1103 for (int removeIndex : removeFromDivideFeatureIdToCountMap) 1104 divideFeatureIdToCountMap.remove(removeIndex); 1105 1106 boolean otherExists = false; 1107 1108 if (totalInOther > 0) 1109 otherExists = true; 1110 1111 if ((totalInOther < divideThreshold && featureIdsToCreateSeparateBranchesForSet 1112 .size() <= 1) 1113 || featureIdsToCreateSeparateBranchesForSet.size() == 0) { 1114 // Node enough instances, make this a leaf node 1115 setIsLeafNode(); 1116 } else { 1117 1118 // If total in other is less then divideThreshold then add the 1119 // smallest of the other parts to other 1120 if (otherExists && totalInOther < divideThreshold) { 1121 int smallestSoFar = Integer.MAX_VALUE; 1122 int smallestSoFarId = Integer.MAX_VALUE; 1123 for (Entry<Integer, Integer> entry : divideFeatureIdToCountMap 1124 .entrySet()) { 1125 if (entry.getValue() < smallestSoFar) { 1126 smallestSoFar = entry.getValue(); 1127 smallestSoFarId = entry.getKey(); 1128 } 1129 } 1130 1131 featureIdsToCreateSeparateBranchesForSet 1132 .remove(smallestSoFarId); 1133 } 1134 1135 // Create new files for all feature ids with count value greater 1136 // than divideThreshold and one for the 1137 // other branch 1138 leafModel.getMethod().divideByFeatureSet( 1139 featureIdsToCreateSeparateBranchesForSet, 1140 divideFeatureIndexVector, "" + OTHER_BRANCH_ID); 1141 1142 for (int id : featureIdsToCreateSeparateBranchesForSet) { 1143 DecisionTreeModel newBranch = new DecisionTreeModel(id, 1144 getSubFeatureVector(), this, 1145 createNextLevelDivideFeatures(), divideThreshold); 1146 branches.put(id, newBranch); 1147 1148 } 1149 if (otherExists) { 1150 DecisionTreeModel newBranch = new DecisionTreeModel( 1151 OTHER_BRANCH_ID, featureVector, this, 1152 new LinkedList<FeatureFunction>(), divideThreshold); 1153 branches.put(OTHER_BRANCH_ID, newBranch); 1154 1155 } 1156 1157 for (DecisionTreeModel b : branches.values()) 1158 b.noMoreInstances(); 1159 1160 } 1161 1162 } 1163 1164 } 1165 1166 @Override 1167 public void terminate() throws MaltChainedException { 1168 if (branches != null) { 1169 for (DecisionTreeModel branch : branches.values()) { 1170 branch.terminate(); 1171 } 1172 branches = null; 1173 } 1174 if (leafModel != null) { 1175 leafModel.terminate(); 1176 leafModel = null; 1177 } 1178 1179 } 1180 1181 public void setParent(Model parent) { 1182 this.parent = parent; 1183 } 1184 1185 public Model getParent() { 1186 return parent; 1187 } 1188 1189 public void setFrequency(int frequency) { 1190 this.frequency = frequency; 1191 } 1192 1193 public int getFrequency() { 1194 return frequency; 1195 } 1196 1197 public int getModelIndex() { 1198 return modelIndex; 1199 } 1200 1201 1202 private LinkedList<FeatureFunction> createGainRatioSplitList(LinkedList<FeatureFunction> divideFeatures) { 1203 1204 if (getGuide().getConfiguration().getConfigLogger().isInfoEnabled()) { 1205 1206 getGuide().getConfiguration().getConfigLogger().info( 1207 "Start calculating gain ratio for all posible divide features"); 1208 } 1209 1210 //Calculate the root entropy 1211 1212 double total = 0; 1213 1214 for(int count: classIdToCountMap.values()){ 1215 double fraction = ((double)count) / getFrequency(); 1216 total = total + fraction*log2(fraction); 1217 } 1218 1219 double rootEntropy = -total; 1220 1221 1222 class FeatureFunctionInformationGainPair implements Comparable<FeatureFunctionInformationGainPair>{ 1223 double informationGain; 1224 FeatureFunction featureFunction; 1225 double splitInfo; 1226 1227 public FeatureFunctionInformationGainPair( 1228 FeatureFunction featureFunction) { 1229 super(); 1230 this.featureFunction = featureFunction; 1231 } 1232 1233 public double getGainRatio(){ 1234 return informationGain/splitInfo; 1235 } 1236 1237 @Override 1238 public int compareTo(FeatureFunctionInformationGainPair o) { 1239 1240 int result = 0; 1241 1242 if((this.getGainRatio() - o.getGainRatio()) <0) 1243 result = -1; 1244 else if ((this.getGainRatio() - o.getGainRatio()) >0) 1245 result = 1; 1246 1247 return result; 1248 } 1249 } 1250 1251 ArrayList<FeatureFunctionInformationGainPair> gainRatioList = new ArrayList<FeatureFunctionInformationGainPair>(); 1252 1253 for(FeatureFunction f: divideFeatures) 1254 gainRatioList.add(new FeatureFunctionInformationGainPair(f)); 1255 1256 //For all divide features calculate the gain ratio 1257 1258 for(FeatureFunctionInformationGainPair p : gainRatioList){ 1259 1260 HashMap<Integer, Integer> featureIdToCountMapTmp = featureIdToCountMap.get(p.featureFunction); 1261 1262 HashMap<Integer, HashMap<Integer, Integer>> featureIdToClassIdToCountMapTmp = featureIdToClassIdToCountMap.get(p.featureFunction); 1263 1264 double sum = 0; 1265 1266 for(Entry<Integer, Integer> entry:featureIdToCountMapTmp.entrySet()){ 1267 int featureId = entry.getKey(); 1268 int numberOfElementsWithFeatureId = entry.getValue(); 1269 HashMap<Integer, Integer> classIdToCountMapTmp = featureIdToClassIdToCountMapTmp.get(featureId); 1270 1271 double sumImpurityMesure = 0; 1272 int totalElementsWithIdAndClass = 0; 1273 for(int elementsWithIdAndClass : classIdToCountMapTmp.values()){ 1274 1275 double fractionOfInstancesBelongingToClass = ((double)elementsWithIdAndClass)/numberOfElementsWithFeatureId; 1276 1277 totalElementsWithIdAndClass = totalElementsWithIdAndClass + elementsWithIdAndClass; 1278 1279 sumImpurityMesure= sumImpurityMesure+fractionOfInstancesBelongingToClass*log2(fractionOfInstancesBelongingToClass); 1280 1281 } 1282 1283 double impurityMesure = -sumImpurityMesure; 1284 1285 sum = sum + (((double)numberOfElementsWithFeatureId)/getFrequency())*impurityMesure; 1286 1287 } 1288 p.informationGain = rootEntropy - sum; 1289 1290 //Calculate split info 1291 1292 double splitInfoTotal = 0; 1293 1294 for(int nrOfElementsWithFeatureId:featureIdToCountMapTmp.values()){ 1295 double fractionOfTotal = ((double)nrOfElementsWithFeatureId)/getFrequency(); 1296 splitInfoTotal = splitInfoTotal + fractionOfTotal*log2(fractionOfTotal); 1297 } 1298 p.splitInfo= splitInfoTotal; 1299 1300 1301 } 1302 Collections.sort(gainRatioList); 1303 1304 1305 1306 //Log the result if info is enabled 1307 if (getGuide().getConfiguration().getConfigLogger().isInfoEnabled()) { 1308 1309 getGuide().getConfiguration().getConfigLogger().info( 1310 "Gain ratio calculation finished the result follows:\n"); 1311 getGuide().getConfiguration().getConfigLogger().info( 1312 "Divide Feature\tGain Ratio\tInformation Gain\tSplit Info\n"); 1313 1314 for(FeatureFunctionInformationGainPair p :gainRatioList) 1315 getGuide().getConfiguration().getConfigLogger().info( 1316 p.featureFunction + "\t" + p.getGainRatio() + "\t" + p.informationGain + "\t" + p.splitInfo +"\n"); 1317 } 1318 1319 LinkedList<FeatureFunction> divideFeaturesNew = new LinkedList<FeatureFunction>(); 1320 1321 for(FeatureFunctionInformationGainPair p :gainRatioList) 1322 divideFeaturesNew.add(p.featureFunction); 1323 1324 1325 return divideFeaturesNew; 1326 1327 } 1328 1329 }