Создание фиктивных переменных в SparkR [duplicate]

Если во время проверки не появляется ошибка MySQL, убедитесь, что вы правильно создали таблицу базы данных. Это случилось со мной. Ищите любые нежелательные запятые или цитаты.

22
задан Joshua Taylor 10 November 2015 в 14:43
поделиться

6 ответов

Начиная с Spark 1.6 вы можете использовать функцию pivot на GroupedData и предоставить обобщенное выражение.

pivoted = (df
    .groupBy("ID", "Age")
    .pivot(
        "Country",
        ['US', 'UK', 'CA'])  # Optional list of levels
    .sum("Score"))  # alternatively you can use .agg(expr))
pivoted.show()

## +---+---+---+---+---+
## | ID|Age| US| UK| CA|
## +---+---+---+---+---+
## |X01| 41|  3|  1|  2|
## |X02| 72|  4|  6|  7|
## +---+---+---+---+---+

Уровни могут быть опущены, но если они могут повысить производительность и служит в качестве внутреннего фильтра.

Этот метод по-прежнему относительно медленный, но, конечно, бит вручную передает данные вручную между JVM и Python.

14
ответ дан zero323 24 August 2018 в 22:26
поделиться

Прежде всего, это, вероятно, не очень хорошая идея, потому что вы не получаете никакой дополнительной информации, но вы привязываетесь к фиксированной схеме (то есть вам нужно знать, сколько стран вы ожидаете, и, конечно же, дополнительная страна означает изменение кода)

Сказав это, это проблема SQL, которая показана ниже. Но если вы полагаете, что это не слишком «программное обеспечение» (серьезно, я это слышал !!), то вы можете отсылать первое решение.

Решение 1:

def reshape(t):
    out = []
    out.append(t[0])
    out.append(t[1])
    for v in brc.value:
        if t[2] == v:
            out.append(t[3])
        else:
            out.append(0)
    return (out[0],out[1]),(out[2],out[3],out[4],out[5])
def cntryFilter(t):
    if t[2] in brc.value:
        return t
    else:
        pass

def addtup(t1,t2):
    j=()
    for k,v in enumerate(t1):
        j=j+(t1[k]+t2[k],)
    return j

def seq(tIntrm,tNext):
    return addtup(tIntrm,tNext)

def comb(tP,tF):
    return addtup(tP,tF)


countries = ['CA', 'UK', 'US', 'XX']
brc = sc.broadcast(countries)
reshaped = calls.filter(cntryFilter).map(reshape)
pivot = reshaped.aggregateByKey((0,0,0,0),seq,comb,1)
for i in pivot.collect():
    print i

Теперь решение 2: Конечно, лучше, поскольку SQL является правильным инструментом для этого

callRow = calls.map(lambda t:   

Row(userid=t[0],age=int(t[1]),country=t[2],nbrCalls=t[3]))
callsDF = ssc.createDataFrame(callRow)
callsDF.printSchema()
callsDF.registerTempTable("calls")
res = ssc.sql("select userid,age,max(ca),max(uk),max(us),max(xx)\
                    from (select userid,age,\
                                  case when country='CA' then nbrCalls else 0 end ca,\
                                  case when country='UK' then nbrCalls else 0 end uk,\
                                  case when country='US' then nbrCalls else 0 end us,\
                                  case when country='XX' then nbrCalls else 0 end xx \
                             from calls) x \
                     group by userid,age")
res.show()

данных:

data=[('X01',41,'US',3),('X01',41,'UK',1),('X01',41,'CA',2),('X02',72,'US',4),('X02',72,'UK',6),('X02',72,'CA',7),('X02',72,'XX',8)]
 calls = sc.parallelize(data,1)
countries = ['CA', 'UK', 'US', 'XX']

Результат:

Из 1-го решения

(('X02', 72), (7, 6, 4, 8)) 
(('X01', 41), (2, 1, 3, 0))

Из второго решения:

root  |-- age: long (nullable = true)  
      |-- country: string (nullable = true)  
      |-- nbrCalls: long (nullable = true)  
      |-- userid: string (nullable = true)

userid age ca uk us xx 
 X02    72  7  6  4  8  
 X01    41  2  1  3  0

Пожалуйста, дайте мне знать, если это работает, или нет:)

Best Ayan

7
ответ дан ayan guha 24 August 2018 в 22:26
поделиться

Итак, сначала я должен был внести эту коррекцию в ваш RDD (который соответствует вашему фактическому выходу):

rdd = sc.parallelize([('X01',41,'US',3),
                      ('X01',41,'UK',1),
                      ('X01',41,'CA',2),
                      ('X02',72,'US',4),
                      ('X02',72,'UK',6),
                      ('X02',72,'CA',7),
                      ('X02',72,'XX',8)])

Как только я сделал эту коррекцию, это сделало трюк:

df.select($"ID", $"Age").groupBy($"ID").agg($"ID", first($"Age") as "Age")
.join(
    df.select($"ID" as "usID", $"Country" as "C1",$"Score" as "US"),
    $"ID" === $"usID" and $"C1" === "US"
)
.join(
    df.select($"ID" as "ukID", $"Country" as "C2",$"Score" as "UK"),
    $"ID" === $"ukID" and $"C2" === "UK"
)
.join(
    df.select($"ID" as "caID", $"Country" as "C3",$"Score" as "CA"), 
    $"ID" === $"caID" and $"C3" === "CA"
)
.select($"ID",$"Age",$"US",$"UK",$"CA")

Не так элегантно, как ваш стержень, конечно.

1
ответ дан David Griffin 24 August 2018 в 22:26
поделиться

Вот родной подход Spark, который не затрудняет имена столбцов. Он основан на aggregateByKey и использует словарь для сбора столбцов, которые отображаются для каждого ключа. Затем мы собираем все имена столбцов, чтобы создать окончательный файл данных. [Предварительная версия использовала jsonRDD после испускания словаря для каждой записи, но это более эффективно.] Ограничение на конкретный список столбцов или исключение таких, как XX, было бы легкой модификацией.

Производительность кажется хорошим даже на довольно больших столах. Я использую вариацию, которая подсчитывает количество раз, каждое из которых имеет переменное число событий для каждого идентификатора, генерируя один столбец для каждого типа события. Код в основном тот же, за исключением того, что для подсчета вхождений используется коллекция.Counter вместо dict в seqFn.

from pyspark.sql.types import *

rdd = sc.parallelize([('X01',41,'US',3),
                       ('X01',41,'UK',1),
                       ('X01',41,'CA',2),
                       ('X02',72,'US',4),
                       ('X02',72,'UK',6),
                       ('X02',72,'CA',7),
                       ('X02',72,'XX',8)])

schema = StructType([StructField('ID', StringType(), True),
                     StructField('Age', IntegerType(), True),
                     StructField('Country', StringType(), True),
                     StructField('Score', IntegerType(), True)])

df = sqlCtx.createDataFrame(rdd, schema)

def seqPivot(u, v):
    if not u:
        u = {}
    u[v.Country] = v.Score
    return u

def cmbPivot(u1, u2):
    u1.update(u2)
    return u1

pivot = (
    df
    .rdd
    .keyBy(lambda row: row.ID)
    .aggregateByKey(None, seqPivot, cmbPivot)
)
columns = (
    pivot
    .values()
    .map(lambda u: set(u.keys()))
    .reduce(lambda s,t: s.union(t))
)
result = sqlCtx.createDataFrame(
    pivot
    .map(lambda (k, u): [k] + [u.get(c) for c in columns]),
    schema=StructType(
        [StructField('ID', StringType())] + 
        [StructField(c, IntegerType()) for c in columns]
    )
)
result.show()

Производит:

ID  CA UK US XX  
X02 7  6  4  8   
X01 2  1  3  null
4
ответ дан patricksurry 24 August 2018 в 22:26
поделиться

Просто некоторые комментарии к очень полезному ответу patricksurry:

  • отсутствует столбец Age, поэтому просто добавьте u ["Age"] = v.Age к функции seqPivot
  • оказалось, что обе петли над элементами столбцов давали элементы в другом порядке. Значения столбцов были правильными, но не именами их. Чтобы избежать этого, просто закажите список столбцов.

Вот немного модифицированный код:

from pyspark.sql.types import *

rdd = sc.parallelize([('X01',41,'US',3),
                       ('X01',41,'UK',1),
                       ('X01',41,'CA',2),
                       ('X02',72,'US',4),
                       ('X02',72,'UK',6),
                       ('X02',72,'CA',7),
                       ('X02',72,'XX',8)])

schema = StructType([StructField('ID', StringType(), True),
                     StructField('Age', IntegerType(), True),
                     StructField('Country', StringType(), True),
                     StructField('Score', IntegerType(), True)])

df = sqlCtx.createDataFrame(rdd, schema)

# u is a dictionarie
# v is a Row
def seqPivot(u, v):
    if not u:
        u = {}
    u[v.Country] = v.Score
    # In the original posting the Age column was not specified
    u["Age"] = v.Age
    return u

# u1
# u2
def cmbPivot(u1, u2):
    u1.update(u2)
    return u1

pivot = (
    rdd
    .map(lambda row: Row(ID=row[0], Age=row[1], Country=row[2],  Score=row[3]))
    .keyBy(lambda row: row.ID)
    .aggregateByKey(None, seqPivot, cmbPivot)
)

columns = (
    pivot
    .values()
    .map(lambda u: set(u.keys()))
    .reduce(lambda s,t: s.union(t))
)

columns_ord = sorted(columns)

result = sqlCtx.createDataFrame(
    pivot
    .map(lambda (k, u): [k] + [u.get(c, None) for c in columns_ord]),
        schema=StructType(
            [StructField('ID', StringType())] + 
            [StructField(c, IntegerType()) for c in columns_ord]
        )
    )

print result.show()

Наконец, выход должен быть

+---+---+---+---+---+----+
| ID|Age| CA| UK| US|  XX|
+---+---+---+---+---+----+
|X02| 72|  7|  6|  4|   8|
|X01| 41|  2|  1|  3|null|
+---+---+---+---+---+----+
1
ответ дан rolpat 24 August 2018 в 22:26
поделиться

В PIVOT есть JIRA, чтобы сделать это изначально, без огромного оператора CASE для каждого значения:

https://issues.apache.org/jira/browse/HIVE -3776

Прошу проголосовать за JIRA, чтобы она была реализована раньше. Как только в Hive SQL, Spark обычно не слишком сильно отстает, и в конечном итоге он будет реализован и в Spark.

0
ответ дан Tagar 24 August 2018 в 22:26
поделиться
Другие вопросы по тегам:

Похожие вопросы: