# Import SparkSession from the pyspark.sql module
from pyspark.sql import SparkSession
def split2df(prod_df, ratio=0.8):
    # Number of rows that go into the first dataframe
    length = int(prod_df.count() * ratio)
    # Take the first `length` rows for the first split;
    # note that limit() on an unordered dataframe does not
    # guarantee which rows are returned
    temp_df = prod_df.limit(length)
    # The second split is whatever is left over; subtract()
    # is a set difference, so the rows are assumed distinct
    temp_df2 = prod_df.subtract(temp_df)
    return temp_df, temp_df2
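As the comments note, limit() picks an unspecified set of rows when the dataframe has no ordering, so repeated runs can produce different splits. If a reproducible split matters, one option is to number the rows explicitly and filter on that index. The following is a minimal sketch, not part of the original helper; the function name split2df_deterministic and the choice of "Brand" as the sort key are illustrative assumptions:

from pyspark.sql import functions as F
from pyspark.sql.window import Window

def split2df_deterministic(df, ratio=0.8, order_col="Brand"):
    # Assign a stable 1-based row index by sorting on `order_col`.
    # A window without partitionBy pulls all rows into one
    # partition, which is fine for small data like this example.
    w = Window.orderBy(order_col)
    indexed = df.withColumn("_rn", F.row_number().over(w))
    cutoff = int(df.count() * ratio)
    # Rows up to the cutoff go left, the rest go right;
    # drop the helper index column before returning
    left = indexed.filter(F.col("_rn") <= cutoff).drop("_rn")
    right = indexed.filter(F.col("_rn") > cutoff).drop("_rn")
    return left, right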
# Create a SparkSession and give the app a name
spark = SparkSession.builder.appName('sparkdf').getOrCreate()
# Column names for the dataframe
columns = ["Brand", "Product"]
# Row data for the dataframe
data = [
("HP", "Laptop"),
("Lenovo", "Mouse"),
("Dell", "Keyboard"),
("Samsung", "Monitor"),
("MSI", "Graphics Card"),
("Asus", "Motherboard"),
("Gigabyte", "Motherboard"),
("Zebronics", "Cabinet"),
("Adata", "RAM"),
("Transcend", "SSD"),
("Kingston", "HDD"),
("Toshiba", "DVD Writer")
]
# Create the dataframe using the above values
prod_df = spark.createDataFrame(data=data, schema=columns)
# View the dataframe
prod_df.show()
# Split into an 80/20 pair of dataframes and show both parts
df1, df2 = split2df(prod_df)
df1.show(truncate=False)
df2.show(truncate=False)
Output (the original dataframe, followed by the 80% and 20% splits):
+---------+-------------+
| Brand| Product|
+---------+-------------+
| HP| Laptop|
| Lenovo| Mouse|
| Dell| Keyboard|
| Samsung| Monitor|
| MSI|Graphics Card|
| Asus| Motherboard|
| Gigabyte| Motherboard|
|Zebronics| Cabinet|
| Adata| RAM|
|Transcend| SSD|
| Kingston| HDD|
| Toshiba| DVD Writer|
+---------+-------------+

+---------+-------------+
|Brand |Product |
+---------+-------------+
|HP |Laptop |
|Lenovo |Mouse |
|Dell |Keyboard |
|Samsung |Monitor |
|MSI |Graphics Card|
|Asus |Motherboard |
|Gigabyte |Motherboard |
|Zebronics|Cabinet |
|Adata |RAM |
+---------+-------------+

+---------+----------+
|Brand |Product |
+---------+----------+
|Transcend|SSD |
|Toshiba |DVD Writer|
|Kingston |HDD |
+---------+----------+
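Note that the rows in the second split come back in a different order than the input; subtract() involves a shuffle, so no ordering is preserved. For a quick, approximate split PySpark also ships a built-in, randomSplit(), which divides rows according to the given weights. Unlike the helper above it does not guarantee exact split sizes, but it is a one-liner. A minimal sketch against the same prod_df:

# randomSplit() takes a list of weights and an optional seed;
# the resulting sizes are approximate, not exact
train_df, test_df = prod_df.randomSplit([0.8, 0.2], seed=42)
train_df.show(truncate=False)
test_df.show(truncate=False)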
Reference:
https://www.geeksforgeeks.org/pyspark-split-dataframe-into-equal-number-of-rows/