それはフル機能のソリューションではありませんが、あなたはこのような何かを始めることができます:
import org.apache.spark.ml.{UnaryTransformer}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
class NGramTokenizer(override val uid: String)
extends UnaryTransformer[String, Seq[String], NGramTokenizer] {
def this() = this(Identifiable.randomUID("ngramtokenizer"))
override protected def createTransformFunc: String => Seq[String] = {
getFeatures _
}
override protected def validateInputType(inputType: DataType): Unit = {
require(inputType == StringType)
}
override protected def outputDataType: DataType = {
new ArrayType(StringType, true)
}
}
クイックチェック:あなたも何かにそれを一般化しようとすることができます
val df = Seq((1L, "abcdef"), (2L, "foobar")).toDF("k", "v")
val transformer = new NGramTokenizer().setInputCol("v").setOutputCol("vs")
transformer.transform(df).show
// +---+------+------------------+
// | k| v| vs|
// +---+------+------------------+
// | 1|abcdef|[f, ef, def, cdef]|
// | 2|foobar|[r, ar, bar, obar]|
// +---+------+------------------+
これは:
import org.apache.spark.sql.catalyst.ScalaReflection.schemaFor
import scala.reflect.runtime.universe._
class UnaryUDFTransformer[T : TypeTag, U : TypeTag](
override val uid: String,
f: T => U
) extends UnaryTransformer[T, U, UnaryUDFTransformer[T, U]] {
override protected def createTransformFunc: T => U = f
override protected def validateInputType(inputType: DataType): Unit =
require(inputType == schemaFor[T].dataType)
override protected def outputDataType: DataType = schemaFor[U].dataType
}
val transformer = new UnaryUDFTransformer("featurize", getFeatures)
.setInputCol("v")
.setOutputCol("vs")
ラップされた関数ではなくUDFを使用する場合は、Transformer
を直接拡張し、transform
メソッドをオーバーライドする必要があります。残念ながら、有用なクラスの大部分はプライベートですので、ややこしいことがあります。
また、あなたがUDFを登録することができます。
spark.udf.register("getFeatures", getFeatures _)
とSQLTransformer
import org.apache.spark.ml.feature.SQLTransformer
val transformer = new SQLTransformer()
.setStatement("SELECT *, getFeatures(v) AS vs FROM __THIS__")
transformer.transform(df).show
// +---+------+------------------+
// | k| v| vs|
// +---+------+------------------+
// | 1|abcdef|[f, ef, def, cdef]|
// | 2|foobar|[r, ar, bar, obar]|
// +---+------+------------------+
私はモデルを保存しようとしましたが、 'Message:Pipeline書き込みは、書き込み可能を実装していないステージが含まれているため、このパイプラインでは失敗します。書き込み不可能なステージ:ngramtokenizer_f784079e2124タイプのクラス '私はWritableインターフェイスを実装する必要がありますか? –
これは私が以前に言及した悪い部分です。私が知る限り、最良のアプローチは 'DefaultParamsWritable'と' DefaultParamsReadable'を実装することですが、あなたのコードの少なくとも一部をMLパッケージに入れないと実行できません。あなたは 'MLWritable' /' MLReadable'でも試してみることができます。 – zero323