from sklearn.model_selection import train_test_split
# Features come from the 'vectors' entry of `train`, labels from its 'female' entry.
# NOTE(review): `train` is presumably a pandas DataFrame -- confirm against the loader.
X = train['vectors']
y = train['female']

# Convert each per-sample vector to an ndarray, then stack them into one array;
# the labels are materialised into a flat array the same way.
X_array = np.array(list(map(np.array, X)))
y_array = np.array(list(y))

# Hold out a test split; random_state is pinned so the split is reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X_array,
    y_array,
    random_state=0,
)
def plot_learning_curve(estimator, title, X, y, ylim=None, cv=None,
                        n_jobs=-1, train_sizes=np.linspace(.1, 1.0, 5)):
    """Plot training and cross-validation scores against training-set size.

    Opens a new figure, calls ``learning_curve`` for ``estimator`` on
    ``X``/``y`` and draws, for the training scores and the validation scores,
    the mean along axis 1 as a marked line plus a translucent band of one
    standard deviation either side of it.

    Args:
        estimator: Model passed straight through to ``learning_curve``.
        title: Text used as the figure title.
        X: Feature data passed to ``learning_curve``.
        y: Target data passed to ``learning_curve``.
        ylim: Optional sequence unpacked into ``plt.ylim``; the axis limits
            are left untouched when it is ``None``.
        cv: Cross-validation setting passed to ``learning_curve``.
        n_jobs: Parallelism setting passed to ``learning_curve``.
        train_sizes: Training-set sizes passed to ``learning_curve``.

    Returns:
        The ``plt`` object used for drawing, so the caller can keep working
        with the current figure.
    """
    # Figure scaffolding is set up before the (potentially slow) scoring run.
    plt.figure()
    plt.title(title)
    if ylim is not None:
        plt.ylim(*ylim)
    plt.xlabel('Training examples')
    plt.ylabel('Score')

    # NOTE(review): presumably scikit-learn's learning_curve, imported
    # elsewhere in the file -- confirm.
    sizes, fit_scores, cv_scores = learning_curve(
        estimator, X, y, cv=cv, n_jobs=n_jobs, train_sizes=train_sizes)

    # One entry per curve: (mean, std, colour, legend label), each statistic
    # reduced over axis 1 of the score matrix.
    curves = [
        (np.mean(scores, axis=1), np.std(scores, axis=1), colour, label)
        for scores, colour, label in (
            (fit_scores, 'r', 'Training score'),
            (cv_scores, 'g', 'Cross-validation score'),
        )
    ]

    plt.grid()

    # Shaded +/- one-std bands are drawn first, then the mean lines.
    for centre, spread, colour, _ in curves:
        plt.fill_between(sizes, centre - spread, centre + spread,
                         alpha=0.1, color=colour)
    for centre, _, colour, label in curves:
        plt.plot(sizes, centre, 'o-', color=colour, label=label)

    plt.legend(loc='best')
    return plt
# Draw one learning-curve figure per tuned model, all on the training split
# and with the shared `kfold` cross-validation setting.
# NOTE(review): gsLR / gsRF / gsSGD are presumably fitted search objects
# (they expose `best_estimator_`) defined earlier in the notebook -- confirm.
# The first title previously read 'LR mearning curves' (typo); corrected here.
g = plot_learning_curve(gsLR.best_estimator_, 'LR learning curves', X_train, y_train, cv=kfold)
g = plot_learning_curve(gsRF.best_estimator_, 'RF learning curves', X_train, y_train, cv=kfold)
g = plot_learning_curve(gsSGD.best_estimator_, 'SGD learning curves', X_train, y_train, cv=kfold)
Random Forestも良い気がしてきた。何とかトレーニングデータを増やしてもう少し確認したい。
最後に
トレーニングに全く使用していない未知のデータに対して Random Forest の予測モデルで判定してみたところ、83%程の正解率だった。なろう小説のTop300はファンタジー系の小説が多いため、これはファンタジー系の小説における正解率とみた方が良いかも知れない。ジャンルが違う小説の場合にどのような結果になるかは、現時点では予測不能である。
from gensim.models.doc2vec import Doc2Vec
# Location of the saved Doc2Vec model file.
MODEL_DATA_PATH = 'drive/My Drive/Colab Notebooks/syosetu/doc2vec_100.model'

# Load the model from disk.
m = Doc2Vec.load(MODEL_DATA_PATH)

# Pull out every document vector by integer position, in index order.
doc_count = len(m.docvecs)
vectors_list = [m.docvecs[idx] for idx in range(doc_count)]
# Point JAVA_HOME and HADOOP_HOME at the installs under the user's home directory.
export JAVA_HOME="${HOME}/jdk-11-latest"
export HADOOP_HOME="${HOME}/hadoop-latest"

# Export the Hadoop configuration directory only when HADOOP_HOME exists as a directory.
if test -d "${HADOOP_HOME}"; then
    export HADOOP_CONF_DIR="${HADOOP_HOME}/etc/hadoop"
fi
OpenJDK Server VM warning: You have loaded library /home/xxx/hadoop-3.2.2/lib/native/libhadoop.so.1.0.0 which might have disabled stack guard. The VM will try to fix the stack guard now.
It's highly recommended that you fix the library with 'execstack -c <libfile>', or link it with '-z noexecstack'.