Spark ML
์คํํฌ์ ์ฌ๋ฌ ์ปดํฌ๋ํธ
โข
Spark SQL
โข
Spark Streaming
โข
MLlib
โข
GraphX ๋ฑ๋ฑ..
MLlib(Machine Learing Library)์ ML์ ์ฝ๊ณ ํ์ฅ์ฑ ์๊ฒ ์ ์ฉํ๊ธฐ ์ํด,
๋จธ์ ๋ฌ๋ ํ์ดํ๋ผ์ธ ๊ฐ๋ฐ์ ์ฝ๊ฒ ํ๊ธฐ ์ํด ๋ง๋ค์ด์ก๋ค.
Machine Learning ์ด๋?
โข
๋ฐ์ดํฐ๋ฅผ ์ด์ฉํด ์ฝ๋ฉ์ ํ๋ ์ผ
โข
์ต์ ํ์ ๊ฐ์ ๋ฐฉ๋ฒ์ ํตํด ํจํด์ ์ฐพ๋ ์ผ
๋จธ์ ๋ฌ๋ ํ์ดํ๋ผ์ธ ๊ตฌ์ฑ
๋ฐ์ดํฐ ๋ก๋ฉ โ ์ ์ฒ๋ฆฌ โ ํ์ต โ ๋ชจ๋ธ ํ๊ฐ
MLlib์ DataFrame์์์ ๋์
์์ง RDD API๊ฐ ์์ง๋ง maintenance mode์ด๋ฉฐ, ์๋ก์ด API๋ ๊ฐ๋ฐ ๋๊น.
DataFrame์ ์ฐ๋ MLlib API๋ฅผ Spark ML์ด๋ผ๊ณ ๋ ๋ถ๋ฆ.
SparkML์ ์ฌ์ฉํ๋ ์ด์
๋ค๋ฅธ ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ๋นํด ์คํํฌ๋ ๋์ค์ ์ผ๋ก ์ฌ์ฉ๋๋ ๋ช๋ช ์๊ณ ๋ฆฌ์ฆ๋ง ๊ตฌํ๋์ด ์๋ค.
โ ์๋กญ๊ฑฐ๋ ํซํ ๋ชจ๋ธ์ด ๋์๋ ์คํํฌ์์ ์ฐ๋ ค๋ฉด ๋ค๋ฅธ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ณด๋ค๋ ์กฐ๊ธ ๋ ๊ธฐ๋ค๋ ค์ผ ํ๋ค๋ ๋จ์ ์ ์กด์ฌ.
๊ทธ๋ผ SparkML์ ์ ์ฐ๋๊ฑธ๊น? ๋๋์ ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํ๋๋ฐ ๋งค์ฐ ์ ํฉํ๊ธฐ ๋๋ฌธ.
์คํํฌ๋ ๋ฐ์ดํฐ๋ฅผ ์ธ๋ฉ๋ชจ๋ฆฌ ์์์ ์ฒ๋ฆฌ. ๋ฐ์ดํฐ๋ฅผ ๋ฉ๋ชจ๋ฆฌ์ ์ฌ๋ ค์ ์ฒ๋ฆฌํ๋ฉด ๋์คํฌ๋ฅผ ์ฌ์ฉํ๋ ๋งต๋ฆฌ๋์ค๋ ๋จธํ์๋ณด๋ค 10๋ฐฐ์์ 100๋ฐฐ๊น์ง ๋น ๋ฅธ ๊ฒฐ๊ณผ๋ฅผ ์ป์ด๋ผ ์ ์๋ค.
ํ์ต์ ํ์ํ ์ ์ฒ๋ฆฌ๋ฅผ ์คํํฌ๋ก ์งํํ๊ณ ๋ชจ๋ธ๋ง์ ํ
์ํ๋ก์ฐ์ ๊ฐ์ ํ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ก ์งํํ๊ฑฐ๋, ์คํํฌ ์ง์ ๋ชจ๋ธ๋ก ์ถฉ๋ถํ ํ๋ก์ ํธ๋ผ๋ฉด ๋ชจ๋ธ๋ง๊น์ง ์คํํฌ๋ก ๋ง๋ฌด๋ฆฌํ์ฌ ์์
์ ์๋๋ฅผ ๋์ผ ์ ์๋ค.
MLlib ์ปดํฌ๋ํธ
โข
์๊ณ ๋ฆฌ์ฆ
โฆ
Classification
โฆ
Regression
โฆ
Clustering
โฆ
Recommendation
โข
ํ์ดํ๋ผ์ธ
โฆ
Training
โฆ
Evaluating
โฆ
Tuning
โฆ
Persistence
โข
Feature Engineering
โฆ
Extraction
โฆ
Transformation
โข
Utils
โฆ
Linear algebra
โฆ
Statistics
์ฃผ์ ์ปดํฌ๋ํธ ์๊ฐ
โข
DataFrame
โฆ
MLํ์ดํ๋ผ์ธ์์๋ ๋ฐ์ดํฐํ๋ ์์ด ๊ธฐ๋ณธ ํฌ๋งท์ด๋ฉฐ, ํ
์คํธ์
์ ๋ก๋ฉํ๊ธฐ ์ํด ๊ธฐ๋ณธ์ ์ผ๋ก csv, JSON, Parquet, JDBC๋ฅผ ์ง์. ML ํ์ดํ๋ผ์ธ์์ ๋ค์ 2๊ฐ์ง์ ์๋ก์ด ๋ฐ์ดํฐ์์ค๋ฅผ ์ถ๊ฐ ์ง์ํจ.
โฆ
์ด๋ฏธ์ง ๋ฐ์ดํฐ์์ค
โช
jpeg, png ๋ฑ์ ์ด๋ฏธ์ง๋ค์ ์ง์ ๋ ๋๋ ํ ๋ฆฌ์์ ๋ก๋
โฆ
LIBSVM ๋ฐ์ดํฐ์์ค
โช
label๊ณผ features ๋ ๊ฐ์ ์ปฌ๋ผ์ผ๋ก ๊ตฌ์ฑ๋๋ ๋จธ์ ๋ฌ๋ ํธ๋ ์ด๋ ํฌ๋งท
โช
features ์ปฌ๋ผ์ ๋ฒกํฐ ํํ์ ๊ตฌ์กฐ
โข
Transformer
โฆ
ํผ์ณ ๋ณํ๊ณผ ํ์ต๋ ๋ชจ๋ธ์ ์ถ์ํ
โฆ
๋ชจ๋ Transformer๋ transform() ํจ์๋ฅผ ๊ฐ์ง๊ณ ์์.
โฆ
๋ฐ์ดํฐ๋ฅผ ํ์ต์ด ๊ฐ๋ฅํ ํฌ๋งท์ผ๋ก ๋ฐ๊ฟ.
โฆ
DF๋ฅผ ๋ฐ์ ์๋ก์ด DF๋ฅผ ๋ง๋๋๋ฐ, ๋ณดํต ํ๋ ์ด์์ column์ ๋ํ๊ฒ ๋๋ค.
โฆ
ex) Data Normalization, Tokenization, ์นดํ
๊ณ ๋ฆฌ์ปฌ ๋ฐ์ดํฐ๋ฅผ ์ซ์๋ก (one-hot encoding)
โข
Estimator
โฆ
๋ชจ๋ธ์ ํ์ต ๊ณผ์ ์ ์ถ์ํ
โฆ
๋ชจ๋ Estimator๋ fit() ํจ์๋ฅผ ๊ฐ์ง๊ณ ์์.
โฆ
fit()์ DataFrame์ ๋ฐ์ Model์ ๋ฐํ
โฆ
๋ชจ๋ธ์ ํ๋์ Transformer
ex)
lr = LinearRegression()
model = lr.fit(data)
โข
Evaluator
โฆ
metric์ ๊ธฐ๋ฐ์ผ๋ก ๋ชจ๋ธ์ ์ฑ๋ฅ์ ํ๊ฐ
ex) Root mean squared error (RMSE)
โฆ
๋ชจ๋ธ์ ์ฌ๋ฌ๊ฐ ๋ง๋ค์ด์, ์ฑ๋ฅ์ ํ๊ฐ ํ ๊ฐ์ฅ ์ข์ ๋ชจ๋ธ์ ๋ฝ๋ ๋ฐฉ์์ผ๋ก ๋ชจ๋ธ ํ๋์ ์๋ํ ๊ฐ๋ฅ.
ex) BinaryClassificationEvaluator, CrossValidator
โข
Pipeline
โฆ
ML์ ์ํฌํ๋ก์ฐ
โฆ
์ฌ๋ฌ stage๋ฅผ ๋ด๊ณ ์์ Pipeline(stages=).
โฆ
์ ์ฅ๋ ์ ์์ (persist)
โฆ
Transformer โ Estimator โ Evaluator โ Model
ML Pipeline์ ๊ฒฐ๊ตญ ํ๋ ์ด์์ Transformer์ Estimator๊ฐ ์ฐ๊ฒฐ๋ ๋ชจ๋ธ๋ง ์ํฌํ๋ก์ฐ๋ก, ์
๋ ฅ์ ๋ฐ์ดํฐํ๋ ์์ด๊ณ ์ถ๋ ฅ์ ๋จธ์ ๋ฌ๋ ๋ชจ๋ธ์ธ ๊ฒ์ด๋ค.
ML Pipeline ๊ทธ ์์ฒด๋ Estimator์ด๋ฏ๋ก ์คํ์ fitํจ์์ ํธ์ถ๋ก ์์ํ ์ ์์ผ๋ฉฐ, ์ ์ฅํ๋ค๊ฐ ๋ค์ ๋ก๋ฉํ๋ ๊ฒ์ด ๊ฐ๋ฅํด ํ๋ฒ ํ์ดํ๋ผ์ธ์ ๋ง๋ค์ด๋๋ฉด ๋ฐ๋ณต์ ์ธ ๋ชจ๋ธ ๋น๋ฉ์ด ์ฌ์์ง๋ค.
โข
Parameter
โฆ
Transformer์ Estimator์ ๊ณตํต API๋ก ๋ค์ํ ์ธ์๋ฅผ ์ ์ฉํด์ค.
โฆ
Param(ํ๋์ ์ด๋ฆ๊ณผ ๊ฐ)๊ณผ ParamMap(Param ๋ฆฌ์คํธ) ๋ ์ข
๋ฅ์ ํ๋ผ๋ฏธํฐ๊ฐ ์กด์ฌ.
โฆ
ํ๋ผ๋ฏธํฐ๋ fit (Estimator) ํน์ transform (Transformer)์ ์ธ์๋ก ์ง์ ๊ฐ๋ฅ.
Spark ML ํผ์ณ๋ณํ
โข
Feature Transformer๊ฐ ํ๋ ์ผ
โฆ
๊ธฐ๋ณธ์ ์ผ๋ก ๋จธ์ ๋ฌ๋์์ ๋ชจ๋ ํผ์ณ ๊ฐ๋ค์ ์ซ์ ํ๋์ด์ด์ผ ํ๋ฏ๋กย ํ
์คํธ ํ๋(์นดํ
๊ณ ๋ฆฌ ๊ฐ๋ค)๋ฅผ ์ซ์ ํ๋๋ก ๋ณํ
โฆ
์ซ์ ํ๋๋ผ๊ณ ํด๋ ๊ฐ๋ฅํ ๊ฐ์ ๋ฒ์๋ฅผ ํน์ ๋ฒ์(0๋ถํฐ 1)๋ก ๋ณํํ๋ ํ์คํ๊ฐ ํ์, ์ด๋ฅผย ํผ์ณ ์ค์ผ์ผ๋ง(Scaling) ํน์ ์ ๊ทํ(Normalization)๋ผ๊ณ ํจ.
โข
Feature Extractor๊ฐ ํ๋ ์ผ
โฆ
๊ธฐ์กด ํผ์ณ์์ ์๋ก์ด ํผ์ณ๋ฅผ ์ถ์ถ
โฆ
TF-IDF, Word2Vec ๋ฑ
โฆ
ํ
์คํธ ๋ฐ์ดํฐ๋ฅผ ์ด๋ค ํํ๋ก ์ธ์ฝ๋ฉํ๋ ๊ฒ์ด ์ฌ๊ธฐ์ ํด๋น
โข
StringIndexer: ํ
์คํธ ์นดํ
๊ณ ๋ฆฌ๋ฅผ ์ซ์๋ก ๋ณํ
โฆ
Scikit-Learn์ sklearn.preprocessing ๋ชจ๋ ์๋ ์ฌ๋ฌ ์ธ์ฝ๋(OneHotEncoder, LabelEncoder, OrdinalEncoder ๋ฑ) ์กด์ฌ
โฆ
Spark MLlib์ ๊ฒฝ์ฐ pyspark.ml.feature ๋ชจ๋ ๋ฐ์ ๋ ๊ฐ์ ์ธ์ฝ๋ ์กด์ฌ
โฆ
StringIndexer, OneHotEncoder
โฆ
์ฌ์ฉ๋ฒ์ Indexer ๋ชจ๋ธ์ ๋ง๋ค๊ณ (fit), Indexter ๋ชจ๋ธ๋ก ๋ฐ์ดํฐํ๋ ์์ ๋ณํ(Transform)
โข
Scaler: ์ซ์ ํ๋ ๊ฐ์ ๋ฒ์๋ฅผ 0๊ณผ 1์ฌ์ด๋ก ํ์คํ
โฆ
pyspark.ml.feature ๋ชจ๋ ๋ฐ์ ๋ ๊ฐ์ ์ค์ผ์ผ๋ฌ ์กด์ฌ
โฆ
StandardScaler: ๊ฐ ๊ฐ์์ ํ๊ท ์ ๋นผ๊ณ ์ด๋ฅผ ํ์คํธ์ฐจ๋ก ๋๋. ๊ฐ์ ๋ถํฌ๊ฐ ์ ๊ท๋ถํฌ๋ฅผ ๋ฐ๋ฅด๋ ๊ฒฝ์ฐ ์ฌ์ฉ
โฆ
MinMaxScaler: ๋ชจ๋ ๊ฐ์ 0๊ณผ 1์ฌ์ด๋ก ์ค์ผ์ผ๋ง. ๊ฐ ๊ฐ์์ ์ต์๊ฐ์ ๋นผ๊ณ (์ต๋๊ฐ-์ต์๊ฐ)์ผ๋ก ๋๋
โข
Imputer: ๊ฐ์ด ์๋ ํ๋ ์ฑ์ฐ๊ธฐ
โฆ
๊ฐ์ด ์กด์ฌํ์ง ์๋ ๋ ์ฝ๋๋ค์ด ์กด์ฌํ๋ ํ๋๋ค์ ๊ฒฝ์ฐ ๊ธฐ๋ณธ๊ฐ(ํ๊ท ๊ฐ, ์ค์๊ฐ ๋ฑ)์ ์ ํด ์ฑ์
์ถ์ฒ ๋ชจ๋ธ
์ ์ ๋ณ ์ํ ์ถ์ฒ ํ์ดํ๋ผ์ธ
โข
ALS: Alternating Least Squares
์ถ์ฒ ์๊ณ ๋ฆฌ์ฆ ์ค ํ๋๋ก, ๊ต๋ ์ต์ ์ ๊ณฑ๋ฒ์ด๋ผ๊ณ ๋ ๋ถ๋ฅธ๋ค.
ํ ์ ์ ๊ฐ ๋ณผ ์ ์๋ ์ํ๊ฐ ๋๋ฌด ๋ง๊ธฐ์, ๋ชป๋ณธ ์ํ๋ค์ ํ์ ์ ์์ธกํ๊ณ ๊ฐ์ฅ ๋์ ์ ์๋ถํฐ ์ ์ ์๊ฒ ์ ๋ฌํ๋ ๊ฒ์ด ๋ฐ๋ก ์ถ์ฒ์ด๋ค.
ALS๋ ๋ ํ๋ ฌ ์ค ํ๋๋ฅผ ๊ณ ์ ์ํค๊ณ ๋ค๋ฅธ ํ๋์ ํ๋ ฌ์ ์์ฐจ์ ์ผ๋ก ๋ฐ๋ณตํ๋ฉด์ ์ต์ ํํ๋ ๋ฐฉ์์ด๋ค.
์์ธก ๋ชจ๋ธ
๊ฑฐ๋ฆฌ๋ณ ํ์๋น ์์ธกํ๊ธฐ
โข
Linear Regression (์ ํ ํ๊ท)
์ข
์๋ณ์ y์ ํ ๊ฐ ์ด์์ ๋
๋ฆฝ๋ณ์ x์ ๋ํ ์ ํ ์๊ด ๊ด๊ณ๋ฅผ ๋ชจ๋ธ๋งํ๋ ํ๊ท ๋ถ์ ๋ฐฉ๋ฒ.
์์ ๊ฐ์ด ๋ฐ์ดํฐ๊ฐ ๋ถํฌ๋์ด์์ ๋, ๋ฐ์ดํฐ์ ๋ถํฌ๊ฐ ๊ฐ์ฅ ์ ๋ง๋ ์ ์ ๊ธ๋ ๊ฒ(์ต์ ํ).
โข
RMSE(Root Mean Squared Error)
์์ธก ๋ชจ๋ธ์์ ์์ธกํ ๊ฐ๊ณผ ์ค์ ๊ฐ ์ฌ์ด์ ํ๊ท ์ฐจ์ด๋ฅผ ์ธก์ ํ๋ค.
์์ธก ๋ชจ๋ธ์ด ๋ชฉํ ๊ฐ(์ ํ๋)์ ์ผ๋ง๋ ์ ์์ธกํ ์ ์๋์ง ์ถ์ ํ๋ค.
์ค์ต ์ฝ๋
์ํ ์ถ์ฒ ํ์ดํ๋ผ์ธ
git clone https://github.com/Y-gw/boaz-sparkML.git
Shell
๋ณต์ฌ
ํ์ผ ๋ง๋ค๊ธฐ
docker pull jupyter/all-spark-notebook
docker run -p 8888:8888 -e JUPYTER_ENABLE_LAB=yes -v {LOCAL_PATH}:/home/jovyan --name jupyter jupyter/all-spark-notebook
Shell
๋ณต์ฌ
token ๋ณต์ฌ
http://localhost:8888/
Shell
๋ณต์ฌ
์ ์
โข
DF ๊ตฌ์กฐ
+------+-------+------+----------+
|userId|movieId|rating| timestamp|
+------+-------+------+----------+
| 1| 296| 5.0|1147880044|
| 1| 306| 3.5|1147868817|
| 1| 307| 5.0|1147868828|
| 1| 665| 5.0|1147878820|
| 1| 899| 3.5|1147868510|
| 1| 1088| 4.0|1147868495|
| 1| 1175| 3.5|1147868826|
| 1| 1217| 3.5|1147878326|
| 1| 1237| 5.0|1147868839|
| 1| 1250| 4.0|1147868414|
| 1| 1260| 3.5|1147877857|
| 1| 1653| 4.0|1147868097|
| 1| 2011| 2.5|1147868079|
| 1| 2012| 2.5|1147868068|
| 1| 2068| 2.5|1147869044|
| 1| 2161| 3.5|1147868609|
| 1| 2351| 4.5|1147877957|
| 1| 2573| 4.0|1147878923|
| 1| 2632| 5.0|1147878248|
| 1| 2692| 5.0|1147869100|
+------+-------+------+----------+
only showing top 20 rows
Plain Text
๋ณต์ฌ
โข
ML ๋ผ์ด๋ธ๋ฌ๋ฆฌ ๋ถ๋ฌ์์ ๋ชจ๋ธ ๋ง๋ค๊ธฐ
from pyspark.ml.recommendation import ALS # ALS ์๊ณ ๋ฆฌ์ฆ ๋ถ๋ฌ์ค๊ธฐ
als = ALS(
maxIter=5,
regParam=0.1,
userCol="userId",
itemCol="movieId",
ratingCol="rating",
coldStartStrategy="drop"
)
model = als.fit(train_df) # fit ์ปดํฌ๋ํธ๋ฅผ ํ์ฉํ train
predictions = model.transform(test_df) # transform ์ปดํฌ๋ํธ๋ฅผ ํ์ฉํ model test
predictions.show()
Python
๋ณต์ฌ
+------+-------+------+----------+
|userId|movieId|rating|prediction|
+------+-------+------+----------+
| 76| 1342| 3.5| 2.9047337|
| 85| 1088| 2.0| 3.7284317|
| 132| 1238| 5.0| 3.2149928|
| 132| 1580| 3.0| 3.2497048|
| 137| 1645| 3.0| 3.167203|
| 230| 833| 3.0| 2.5753236|
| 230| 1088| 4.0| 3.115355|
| 243| 1580| 3.0| 2.5723686|
| 319| 1238| 5.0| 3.8150952|
| 333| 1088| 5.0| 4.05824|
| 368| 1580| 3.5| 3.603326|
| 368| 3175| 5.0| 3.5701354|
| 409| 8638| 5.0| 3.9008398|
| 458| 1580| 3.5| 3.1976578|
| 472| 3918| 3.0| 2.3450446|
| 548| 5803| 2.5| 2.6988087|
| 548| 36525| 3.5| 3.169841|
| 548| 82529| 3.0| 3.22782|
| 587| 6466| 4.0| 3.3879008|
| 597| 3997| 1.0| 1.9163384|
+------+-------+------+----------+
only showing top 20 rows
Plain Text
๋ณต์ฌ
โข
๋ชจ๋ธ ํ๊ฐ
from pyspark.ml.evaluation import RegressionEvaluator
evaluator = RegressionEvaluator(metricName="rmse", labelCol='rating', predictionCol='prediction')
rmse = evaluator.evaluate(predictions)
print(rmse)
>> ex) 0.8184303257919787
Python
๋ณต์ฌ
โข
์ถ์ฒ
model.recommendForAllUsers(3).show() # ์ ์ ๋ณ Top3๊ฐ์ ์์ดํ
์ถ์ฒ
Python
๋ณต์ฌ
+------+--------------------+
|userId| recommendations|
+------+--------------------+
| 12|[{151989, 6.29235...|
| 22|[{199187, 7.83498...|
| 26|[{151989, 5.92996...|
| 27|[{203086, 6.41190...|
| 28|[{151989, 8.20413...|
| 31|[{151989, 4.24829...|
| 34|[{151989, 6.02863...|
| 44|[{151989, 7.49052...|
| 47|[{151989, 5.90802...|
| 53|[{151989, 7.39449...|
| 65|[{205277, 6.63911...|
| 76|[{151989, 6.86239...|
| 78|[{151989, 7.89406...|
| 81|[{151989, 4.36021...|
| 85|[{151989, 5.68050...|
| 91|[{203086, 5.94486...|
| 93|[{151989, 6.51991...|
| 101|[{151989, 5.67113...|
| 103|[{151989, 6.55442...|
| 108|[{151989, 6.03695...|
+------+--------------------+
only showing top 20 rows
Plain Text
๋ณต์ฌ
โข
์ ์ ๋ณ ์ถ์ฒ api๋ฅผ ์ํ ์ฝ๋
from pyspark.sql.types import IntegerType
user_list = [65, 78, 81]
users_df = spark.createDataFrame(user_list, IntegerType()).toDF('userId')
Python
๋ณต์ฌ
user_recs = model.recommendForUserSubset(users_df, 5)
movies_list = user_recs.collect()[0].recommendations
recs_df = spark.createDataFrame(movies_list)
Python
๋ณต์ฌ
ํ์๋น ์์ธกํ๊ธฐ
โข
๊ตฌ์กฐ ํ์ธ ๋ฐ ํ์ ์ปฌ๋ผ ์ถ์ถ
trips_df.createOrReplaceTempView("trips") # sql์์ ์ธ ์ ์๊ฒ ๋ณํ
Python
๋ณต์ฌ
query = """
SELECT
trip_distance, #์บ์คํ
ํ์ํ ์๋ ์์.
total_amount
FROM
trips
WHERE
total_amount < 5000
AND total_amount > 0
AND trip_distance > 0
AND trip_distance < 500
AND passenger_count < 4
AND TO_DATE(tpep_pickup_datetime) >= '2021-01-01'
AND TO_DATE(tpep_pickup_datetime) < '2021-04-01'
"""
SQL
๋ณต์ฌ
+-------------+------------+
|trip_distance|total_amount|
+-------------+------------+
| 2.1| 11.8|
| 0.2| 4.3|
| 14.7| 51.95|
| 10.6| 36.35|
| 4.94| 24.36|
| 1.6| 14.15|
| 4.1| 17.3|
| 5.7| 21.8|
| 9.1| 28.8|
| 2.7| 18.95|
| 6.11| 24.3|
| 1.21| 10.79|
| 7.4| 33.92|
| 1.01| 10.3|
| 0.73| 12.09|
| 1.17| 12.36|
| 0.78| 9.96|
| 1.66| 12.3|
| 0.93| 9.3|
| 1.16| 11.84|
+-------------+------------+
only showing top 20 rows
Plain Text
๋ณต์ฌ
โข
feature column ์์ฑ์ ํตํ Train ๋ฐ์ดํฐ์
๊ตฌ์ฑ
from pyspark.ml.feature import VectorAssembler
vassembler = VectorAssembler(inputCols=["trip_distance"], outputCol="features")
vtrain_df = vassembler.transform(train_df)
Python
๋ณต์ฌ
โข
regression ๋ชจ๋ธ ์์ฑ
from pyspark.ml.regression import LinearRegression
lr = LinearRegression(
maxIter=50,
labelCol="total_amount",
featuresCol="features"
)
model = lr.fit(vtrain_df)
vtest_df = vassembler.transform(test_df)
prediction = model.transform(vtest_df)
prediction.show()
Python
๋ณต์ฌ
+-------------+------------+--------+-----------------+
|trip_distance|total_amount|features| prediction|
+-------------+------------+--------+-----------------+
| 0.01| 3.3| [0.01]|8.291036440655487|
| 0.01| 3.3| [0.01]|8.291036440655487|
| 0.01| 3.3| [0.01]|8.291036440655487|
| 0.01| 3.3| [0.01]|8.291036440655487|
| 0.01| 3.3| [0.01]|8.291036440655487|
| 0.01| 3.3| [0.01]|8.291036440655487|
| 0.01| 3.3| [0.01]|8.291036440655487|
| 0.01| 3.3| [0.01]|8.291036440655487|
| 0.01| 3.3| [0.01]|8.291036440655487|
| 0.01| 3.3| [0.01]|8.291036440655487|
| 0.01| 3.3| [0.01]|8.291036440655487|
| 0.01| 3.3| [0.01]|8.291036440655487|
| 0.01| 3.3| [0.01]|8.291036440655487|
| 0.01| 3.3| [0.01]|8.291036440655487|
| 0.01| 3.8| [0.01]|8.291036440655487|
| 0.01| 3.8| [0.01]|8.291036440655487|
| 0.01| 3.8| [0.01]|8.291036440655487|
| 0.01| 3.8| [0.01]|8.291036440655487|
| 0.01| 3.8| [0.01]|8.291036440655487|
| 0.01| 3.8| [0.01]|8.291036440655487|
+-------------+------------+--------+-----------------+
only showing top 20 rows
Plain Text
๋ณต์ฌ
โข
๋ชจ๋ธ ํ๊ฐ
model.summary.rootMeanSquaredError
>> ex) 4.872759850891687
Python
๋ณต์ฌ
model.summary.r2 # total amount์ 82%๊ฐ trip_distance๋ก ์ค๋ช
์ด ๊ฐ๋ฅํ๋ค๋ ๋ง๊ณผ ๊ฐ์.
>> ex) 0.8237208415594777
Python
๋ณต์ฌ
โข
์ง์ ์
๋ ฅ๊ฐ ์ค์ ํด์ ์์ธก
from pyspark.sql.types import DoubleType
distance_list = [1.1, 5.5, 10.5, 30.0]
distance_df = spark.createDataFrame(distance_list, DoubleType()).toDF("trip_distance")
vdistance_df = vassembler.transform(distance_df)
model.transform(vdistance_df).show()
Python
๋ณต์ฌ
+-------------+--------+------------------+
|trip_distance|features| prediction|
+-------------+--------+------------------+
| 1.1| [1.1]|11.770960525327574|
| 5.5| [5.5]|25.818360500150682|
| 10.5| [10.5]|41.781315016995116|
| 30.0| [30.0]|104.03683763268842|
+-------------+--------+------------------+
Plain Text
๋ณต์ฌ